"Fossies" - the Fresh Open Source Software Archive

Member "keystone-19.0.0/keystone/tests/unit/test_backend_sql.py" (14 Apr 2021, 62207 Bytes) of package /linux/misc/openstack/keystone-19.0.0.tar.gz:


As a special service, "Fossies" has tried to format the requested source page into HTML using (guessed) Python source code syntax highlighting (style: standard) with prefixed line numbers. Alternatively, you can view or download the uninterpreted source code file here. See also the latest Fossies "Diffs" side-by-side code changes report for "test_backend_sql.py": 18.0.0_vs_19.0.0.

    1 # Copyright 2012 OpenStack Foundation
    2 #
    3 # Licensed under the Apache License, Version 2.0 (the "License"); you may
    4 # not use this file except in compliance with the License. You may obtain
    5 # a copy of the License at
    6 #
    7 #      http://www.apache.org/licenses/LICENSE-2.0
    8 #
    9 # Unless required by applicable law or agreed to in writing, software
   10 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
   11 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
   12 # License for the specific language governing permissions and limitations
   13 # under the License.
   14 
   15 import datetime
   16 from unittest import mock
   17 import uuid
   18 
   19 import fixtures
   20 import freezegun
   21 from oslo_db import exception as db_exception
   22 from oslo_db import options
   23 from oslo_log import log
   24 import sqlalchemy
   25 from sqlalchemy import exc
   26 from testtools import matchers
   27 
   28 from keystone.common import driver_hints
   29 from keystone.common import provider_api
   30 from keystone.common import sql
   31 from keystone.common.sql import core
   32 import keystone.conf
   33 from keystone.credential.providers import fernet as credential_provider
   34 from keystone import exception
   35 from keystone.identity.backends import sql_model as identity_sql
   36 from keystone.resource.backends import base as resource
   37 from keystone.tests import unit
   38 from keystone.tests.unit.assignment import test_backends as assignment_tests
   39 from keystone.tests.unit.catalog import test_backends as catalog_tests
   40 from keystone.tests.unit import default_fixtures
   41 from keystone.tests.unit.identity import test_backends as identity_tests
   42 from keystone.tests.unit import ksfixtures
   43 from keystone.tests.unit.ksfixtures import database
   44 from keystone.tests.unit.limit import test_backends as limit_tests
   45 from keystone.tests.unit.policy import test_backends as policy_tests
   46 from keystone.tests.unit.resource import test_backends as resource_tests
   47 from keystone.tests.unit.trust import test_backends as trust_tests
   48 from keystone.trust.backends import sql as trust_sql
   49 
   50 
   51 CONF = keystone.conf.CONF
   52 PROVIDERS = provider_api.ProviderAPIs
   53 
   54 
   55 class SqlTests(unit.SQLDriverOverrides, unit.TestCase):
   56 
   57     def setUp(self):
   58         super(SqlTests, self).setUp()
   59         self.useFixture(database.Database())
   60         self.load_backends()
   61 
   62         # populate the engine with tables & fixtures
   63         self.load_fixtures(default_fixtures)
   64         # defaulted by the data load
   65         self.user_foo['enabled'] = True
   66 
   67     def config_files(self):
   68         config_files = super(SqlTests, self).config_files()
   69         config_files.append(unit.dirs.tests_conf('backend_sql.conf'))
   70         return config_files
   71 
   72 
   73 class DataTypeRoundTrips(SqlTests):
   74     def test_json_blob_roundtrip(self):
   75         """Test round-trip of a JSON data structure with JsonBlob."""
   76         with sql.session_for_read() as session:
   77             val = session.scalar(
   78                 sqlalchemy.select(
   79                     [sqlalchemy.literal({"key": "value"}, type_=core.JsonBlob)]
   80                 )
   81             )
   82 
   83         self.assertEqual({"key": "value"}, val)
   84 
   85     def test_json_blob_sql_null(self):
   86         """Test that JsonBlob can accommodate a SQL NULL value in a result set.
   87 
   88         SQL NULL may be handled by JsonBlob in the case where a table is
   89         storing NULL in a JsonBlob column, as several models use this type
   90         in a column that is nullable.   It also comes back when the column
   91         is left NULL from being in an OUTER JOIN.  In Python, this means
   92         the None constant is handled by the datatype.
   93 
   94         """
   95         with sql.session_for_read() as session:
   96             val = session.scalar(
   97                 sqlalchemy.select(
   98                     [sqlalchemy.cast(sqlalchemy.null(), type_=core.JsonBlob)]
   99                 )
  100             )
  101 
  102         self.assertIsNone(val)
  103 
  104     def test_json_blob_python_none(self):
  105         """Test that JsonBlob round-trips a Python None.
  106 
  107         This is where JSON datatypes get a little nutty, in that JSON has
  108         a 'null' keyword, and JsonBlob right now will persist Python None
  109         as the json string 'null', not SQL NULL.
  110 
  111         """
  112         with sql.session_for_read() as session:
  113             val = session.scalar(
  114                 sqlalchemy.select(
  115                     [sqlalchemy.literal(None, type_=core.JsonBlob)]
  116                 )
  117             )
  118 
  119         self.assertIsNone(val)
  120 
  121     def test_json_blob_python_none_renders(self):
  122         """Test that JsonBlob actually renders JSON 'null' for Python None."""
  123         with sql.session_for_read() as session:
  124             val = session.scalar(
  125                 sqlalchemy.select(
  126                     [
  127                         sqlalchemy.cast(
  128                             sqlalchemy.literal(None, type_=core.JsonBlob),
  129                             sqlalchemy.String,
  130                         )
  131                     ]
  132                 )
  133             )
  134 
  135         self.assertEqual("null", val)
  136 
  137     def test_datetimeint_roundtrip(self):
  138         """Test round-trip of a Python datetime with DateTimeInt."""
  139         with sql.session_for_read() as session:
  140             datetime_value = datetime.datetime(2019, 5, 15, 10, 17, 55)
  141             val = session.scalar(
  142                 sqlalchemy.select(
  143                     [
  144                         sqlalchemy.literal(
  145                             datetime_value, type_=core.DateTimeInt
  146                         ),
  147                     ]
  148                 )
  149             )
  150 
  151         self.assertEqual(datetime_value, val)
  152 
  153     def test_datetimeint_persistence(self):
  154         """Test integer persistence with DateTimeInt."""
  155         with sql.session_for_read() as session:
  156             datetime_value = datetime.datetime(2019, 5, 15, 10, 17, 55)
  157             val = session.scalar(
  158                 sqlalchemy.select(
  159                     [
  160                         sqlalchemy.cast(
  161                             sqlalchemy.literal(
  162                                 datetime_value, type_=core.DateTimeInt
  163                             ),
  164                             sqlalchemy.Integer
  165                         )
  166                     ]
  167                 )
  168             )
  169 
  170         self.assertEqual(1557915475000000, val)
  171 
  172     def test_datetimeint_python_none(self):
  173         """Test round-trip of a Python None with DateTimeInt."""
  174         with sql.session_for_read() as session:
  175             val = session.scalar(
  176                 sqlalchemy.select(
  177                     [
  178                         sqlalchemy.literal(None, type_=core.DateTimeInt),
  179                     ]
  180                 )
  181             )
  182 
  183         self.assertIsNone(val)
  184 
  185 
  186 class SqlModels(SqlTests):
  187 
  188     def load_table(self, name):
  189         table = sqlalchemy.Table(name,
  190                                  sql.ModelBase.metadata,
  191                                  autoload=True)
  192         return table
  193 
  194     def assertExpectedSchema(self, table, expected_schema):
  195         """Assert that a table's schema is what we expect.
  196 
  197         :param string table: the name of the table to inspect
  198         :param tuple expected_schema: a tuple of tuples containing the
  199             expected schema
  200         :raises AssertionError: when the database schema doesn't match the
  201             expected schema
  202 
  203         The expected_schema format is simply::
  204 
  205             (
  206                 ('column name', sql type, qualifying detail),
  207                 ...
  208             )
  209 
  210         The qualifying detail varies based on the type of the column::
  211 
  212           - sql.Boolean columns must indicate the column's default value or
  213             None if there is no default
  214           - Columns with a length, like sql.String, must indicate the
  215             column's length
  216           - All other column types should use None
  217 
  218         Example::
  219 
  220             cols = (('id', sql.String, 64),
  221                     ('enabled', sql.Boolean, True),
  222                     ('extra', sql.JsonBlob, None))
  223             self.assertExpectedSchema('table_name', cols)
  224 
  225         """
  226         table = self.load_table(table)
  227 
  228         actual_schema = []
  229         for column in table.c:
  230             if isinstance(column.type, sql.Boolean):
  231                 default = None
  232                 if column.default:
  233                     default = column.default.arg
  234                 actual_schema.append((column.name, type(column.type), default))
  235             elif (hasattr(column.type, 'length') and
  236                     not isinstance(column.type, sql.Enum)):
  237                 # NOTE(dstanek): Even though sql.Enum columns have a length
  238                 # set we don't want to catch them here. Maybe in the future
  239                 # we'll check to see that they contain a list of the correct
  240                 # possible values.
  241                 actual_schema.append((column.name,
  242                                       type(column.type),
  243                                       column.type.length))
  244             else:
  245                 actual_schema.append((column.name, type(column.type), None))
  246 
  247         self.assertItemsEqual(expected_schema, actual_schema)
  248 
  249     def test_user_model(self):
  250         cols = (('id', sql.String, 64),
  251                 ('domain_id', sql.String, 64),
  252                 ('default_project_id', sql.String, 64),
  253                 ('enabled', sql.Boolean, None),
  254                 ('extra', sql.JsonBlob, None),
  255                 ('created_at', sql.DateTime, None),
  256                 ('last_active_at', sqlalchemy.Date, None))
  257         self.assertExpectedSchema('user', cols)
  258 
  259     def test_local_user_model(self):
  260         cols = (('id', sql.Integer, None),
  261                 ('user_id', sql.String, 64),
  262                 ('name', sql.String, 255),
  263                 ('domain_id', sql.String, 64),
  264                 ('failed_auth_count', sql.Integer, None),
  265                 ('failed_auth_at', sql.DateTime, None))
  266         self.assertExpectedSchema('local_user', cols)
  267 
  268     def test_password_model(self):
  269         cols = (('id', sql.Integer, None),
  270                 ('local_user_id', sql.Integer, None),
  271                 ('password_hash', sql.String, 255),
  272                 ('created_at', sql.DateTime, None),
  273                 ('expires_at', sql.DateTime, None),
  274                 ('created_at_int', sql.DateTimeInt, None),
  275                 ('expires_at_int', sql.DateTimeInt, None),
  276                 ('self_service', sql.Boolean, False))
  277         self.assertExpectedSchema('password', cols)
  278 
  279     def test_federated_user_model(self):
  280         cols = (('id', sql.Integer, None),
  281                 ('user_id', sql.String, 64),
  282                 ('idp_id', sql.String, 64),
  283                 ('protocol_id', sql.String, 64),
  284                 ('unique_id', sql.String, 255),
  285                 ('display_name', sql.String, 255))
  286         self.assertExpectedSchema('federated_user', cols)
  287 
  288     def test_nonlocal_user_model(self):
  289         cols = (('domain_id', sql.String, 64),
  290                 ('name', sql.String, 255),
  291                 ('user_id', sql.String, 64))
  292         self.assertExpectedSchema('nonlocal_user', cols)
  293 
  294     def test_group_model(self):
  295         cols = (('id', sql.String, 64),
  296                 ('name', sql.String, 64),
  297                 ('description', sql.Text, None),
  298                 ('domain_id', sql.String, 64),
  299                 ('extra', sql.JsonBlob, None))
  300         self.assertExpectedSchema('group', cols)
  301 
  302     def test_project_model(self):
  303         cols = (('id', sql.String, 64),
  304                 ('name', sql.String, 64),
  305                 ('description', sql.Text, None),
  306                 ('domain_id', sql.String, 64),
  307                 ('enabled', sql.Boolean, None),
  308                 ('extra', sql.JsonBlob, None),
  309                 ('parent_id', sql.String, 64),
  310                 ('is_domain', sql.Boolean, False))
  311         self.assertExpectedSchema('project', cols)
  312 
  313     def test_role_assignment_model(self):
  314         cols = (('type', sql.Enum, None),
  315                 ('actor_id', sql.String, 64),
  316                 ('target_id', sql.String, 64),
  317                 ('role_id', sql.String, 64),
  318                 ('inherited', sql.Boolean, False))
  319         self.assertExpectedSchema('assignment', cols)
  320 
  321     def test_user_group_membership(self):
  322         cols = (('group_id', sql.String, 64),
  323                 ('user_id', sql.String, 64))
  324         self.assertExpectedSchema('user_group_membership', cols)
  325 
  326     def test_revocation_event_model(self):
  327         cols = (('id', sql.Integer, None),
  328                 ('domain_id', sql.String, 64),
  329                 ('project_id', sql.String, 64),
  330                 ('user_id', sql.String, 64),
  331                 ('role_id', sql.String, 64),
  332                 ('trust_id', sql.String, 64),
  333                 ('consumer_id', sql.String, 64),
  334                 ('access_token_id', sql.String, 64),
  335                 ('issued_before', sql.DateTime, None),
  336                 ('expires_at', sql.DateTime, None),
  337                 ('revoked_at', sql.DateTime, None),
  338                 ('audit_id', sql.String, 32),
  339                 ('audit_chain_id', sql.String, 32))
  340         self.assertExpectedSchema('revocation_event', cols)
  341 
  342     def test_project_tags_model(self):
  343         cols = (('project_id', sql.String, 64),
  344                 ('name', sql.Unicode, 255))
  345         self.assertExpectedSchema('project_tag', cols)
  346 
  347 
  348 class SqlIdentity(SqlTests,
  349                   identity_tests.IdentityTests,
  350                   assignment_tests.AssignmentTests,
  351                   assignment_tests.SystemAssignmentTests,
  352                   resource_tests.ResourceTests):
  353     def test_password_hashed(self):
  354         with sql.session_for_read() as session:
  355             user_ref = PROVIDERS.identity_api._get_user(
  356                 session, self.user_foo['id']
  357             )
  358             self.assertNotEqual(self.user_foo['password'],
  359                                 user_ref['password'])
  360 
  361     def test_create_user_with_null_password(self):
  362         user_dict = unit.new_user_ref(
  363             domain_id=CONF.identity.default_domain_id)
  364         user_dict["password"] = None
  365         new_user_dict = PROVIDERS.identity_api.create_user(user_dict)
  366         with sql.session_for_read() as session:
  367             new_user_ref = PROVIDERS.identity_api._get_user(
  368                 session, new_user_dict['id']
  369             )
  370             self.assertIsNone(new_user_ref.password)
  371 
  372     def test_update_user_with_null_password(self):
  373         user_dict = unit.new_user_ref(
  374             domain_id=CONF.identity.default_domain_id)
  375         self.assertTrue(user_dict['password'])
  376         new_user_dict = PROVIDERS.identity_api.create_user(user_dict)
  377         new_user_dict["password"] = None
  378         new_user_dict = PROVIDERS.identity_api.update_user(
  379             new_user_dict['id'], new_user_dict
  380         )
  381         with sql.session_for_read() as session:
  382             new_user_ref = PROVIDERS.identity_api._get_user(
  383                 session, new_user_dict['id']
  384             )
  385             self.assertIsNone(new_user_ref.password)
  386 
  387     def test_delete_user_with_project_association(self):
  388         user = unit.new_user_ref(domain_id=CONF.identity.default_domain_id)
  389         user = PROVIDERS.identity_api.create_user(user)
  390         role_member = unit.new_role_ref()
  391         PROVIDERS.role_api.create_role(role_member['id'], role_member)
  392         PROVIDERS.assignment_api.add_role_to_user_and_project(
  393             user['id'], self.project_bar['id'], role_member['id']
  394         )
  395         PROVIDERS.identity_api.delete_user(user['id'])
  396         self.assertRaises(exception.UserNotFound,
  397                           PROVIDERS.assignment_api.list_projects_for_user,
  398                           user['id'])
  399 
  400     def test_create_user_case_sensitivity(self):
  401         # user name case sensitivity is down to the fact that it is marked as
  402         # an SQL UNIQUE column, which may not be valid for other backends, like
  403         # LDAP.
  404 
  405         # create a ref with a lowercase name
  406         ref = unit.new_user_ref(name=uuid.uuid4().hex.lower(),
  407                                 domain_id=CONF.identity.default_domain_id)
  408         ref = PROVIDERS.identity_api.create_user(ref)
  409 
  410         # assign a new ID with the same name, but this time in uppercase
  411         ref['name'] = ref['name'].upper()
  412         PROVIDERS.identity_api.create_user(ref)
  413 
  414     def test_create_project_case_sensitivity(self):
  415         # project name case sensitivity is down to the fact that it is marked
  416         # as an SQL UNIQUE column, which may not be valid for other backends,
  417         # like LDAP.
  418 
  419         # create a ref with a lowercase name
  420         ref = unit.new_project_ref(domain_id=CONF.identity.default_domain_id)
  421         PROVIDERS.resource_api.create_project(ref['id'], ref)
  422 
  423         # assign a new ID with the same name, but this time in uppercase
  424         ref['id'] = uuid.uuid4().hex
  425         ref['name'] = ref['name'].upper()
  426         PROVIDERS.resource_api.create_project(ref['id'], ref)
  427 
  428     def test_delete_project_with_user_association(self):
  429         user = unit.new_user_ref(domain_id=CONF.identity.default_domain_id)
  430         user = PROVIDERS.identity_api.create_user(user)
  431         role_member = unit.new_role_ref()
  432         PROVIDERS.role_api.create_role(role_member['id'], role_member)
  433         PROVIDERS.assignment_api.add_role_to_user_and_project(
  434             user['id'], self.project_bar['id'], role_member['id']
  435         )
  436         PROVIDERS.resource_api.delete_project(self.project_bar['id'])
  437         projects = PROVIDERS.assignment_api.list_projects_for_user(user['id'])
  438         self.assertEqual([], projects)
  439 
  440     def test_update_project_returns_extra(self):
  441         """Test for backward compatibility with an essex/folsom bug.
  442 
  443         Non-indexed attributes were returned in an 'extra' attribute, instead
  444         of on the entity itself; for consistency and backwards compatibility,
  445         those attributes should be included twice.
  446 
  447         This behavior is specific to the SQL driver.
  448 
  449         """
  450         arbitrary_key = uuid.uuid4().hex
  451         arbitrary_value = uuid.uuid4().hex
  452         project = unit.new_project_ref(
  453             domain_id=CONF.identity.default_domain_id)
  454         project[arbitrary_key] = arbitrary_value
  455         ref = PROVIDERS.resource_api.create_project(project['id'], project)
  456         self.assertEqual(arbitrary_value, ref[arbitrary_key])
  457         self.assertNotIn('extra', ref)
  458 
  459         ref['name'] = uuid.uuid4().hex
  460         ref = PROVIDERS.resource_api.update_project(ref['id'], ref)
  461         self.assertEqual(arbitrary_value, ref[arbitrary_key])
  462         self.assertEqual(arbitrary_value, ref['extra'][arbitrary_key])
  463 
  464     def test_update_user_returns_extra(self):
  465         """Test for backwards-compatibility with an essex/folsom bug.
  466 
  467         Non-indexed attributes were returned in an 'extra' attribute, instead
  468         of on the entity itself; for consistency and backwards compatibility,
  469         those attributes should be included twice.
  470 
  471         This behavior is specific to the SQL driver.
  472 
  473         """
  474         arbitrary_key = uuid.uuid4().hex
  475         arbitrary_value = uuid.uuid4().hex
  476         user = unit.new_user_ref(domain_id=CONF.identity.default_domain_id)
  477         user[arbitrary_key] = arbitrary_value
  478         del user["id"]
  479         ref = PROVIDERS.identity_api.create_user(user)
  480         self.assertEqual(arbitrary_value, ref[arbitrary_key])
  481         self.assertNotIn('password', ref)
  482         self.assertNotIn('extra', ref)
  483 
  484         user['name'] = uuid.uuid4().hex
  485         user['password'] = uuid.uuid4().hex
  486         ref = PROVIDERS.identity_api.update_user(ref['id'], user)
  487         self.assertNotIn('password', ref)
  488         self.assertNotIn('password', ref['extra'])
  489         self.assertEqual(arbitrary_value, ref[arbitrary_key])
  490         self.assertEqual(arbitrary_value, ref['extra'][arbitrary_key])
  491 
  492     def test_sql_user_to_dict_null_default_project_id(self):
  493         user = unit.new_user_ref(domain_id=CONF.identity.default_domain_id)
  494         user = PROVIDERS.identity_api.create_user(user)
  495         with sql.session_for_read() as session:
  496             query = session.query(identity_sql.User)
  497             query = query.filter_by(id=user['id'])
  498             raw_user_ref = query.one()
  499             self.assertIsNone(raw_user_ref.default_project_id)
  500             user_ref = raw_user_ref.to_dict()
  501             self.assertNotIn('default_project_id', user_ref)
  502             session.close()
  503 
  504     def test_list_domains_for_user(self):
  505         domain = unit.new_domain_ref()
  506         PROVIDERS.resource_api.create_domain(domain['id'], domain)
  507         user = unit.new_user_ref(domain_id=domain['id'])
  508 
  509         test_domain1 = unit.new_domain_ref()
  510         PROVIDERS.resource_api.create_domain(test_domain1['id'], test_domain1)
  511         test_domain2 = unit.new_domain_ref()
  512         PROVIDERS.resource_api.create_domain(test_domain2['id'], test_domain2)
  513 
  514         user = PROVIDERS.identity_api.create_user(user)
  515         user_domains = PROVIDERS.assignment_api.list_domains_for_user(
  516             user['id']
  517         )
  518         self.assertEqual(0, len(user_domains))
  519         PROVIDERS.assignment_api.create_grant(
  520             user_id=user['id'], domain_id=test_domain1['id'],
  521             role_id=self.role_member['id']
  522         )
  523         PROVIDERS.assignment_api.create_grant(
  524             user_id=user['id'], domain_id=test_domain2['id'],
  525             role_id=self.role_member['id']
  526         )
  527         user_domains = PROVIDERS.assignment_api.list_domains_for_user(
  528             user['id']
  529         )
  530         self.assertThat(user_domains, matchers.HasLength(2))
  531 
  532     def test_list_domains_for_user_with_grants(self):
  533         # Create two groups each with a role on a different domain, and
  534         # make user1 a member of both groups.  Both these new domains
  535         # should now be included, along with any direct user grants.
  536         domain = unit.new_domain_ref()
  537         PROVIDERS.resource_api.create_domain(domain['id'], domain)
  538         user = unit.new_user_ref(domain_id=domain['id'])
  539         user = PROVIDERS.identity_api.create_user(user)
  540         group1 = unit.new_group_ref(domain_id=domain['id'])
  541         group1 = PROVIDERS.identity_api.create_group(group1)
  542         group2 = unit.new_group_ref(domain_id=domain['id'])
  543         group2 = PROVIDERS.identity_api.create_group(group2)
  544 
  545         test_domain1 = unit.new_domain_ref()
  546         PROVIDERS.resource_api.create_domain(test_domain1['id'], test_domain1)
  547         test_domain2 = unit.new_domain_ref()
  548         PROVIDERS.resource_api.create_domain(test_domain2['id'], test_domain2)
  549         test_domain3 = unit.new_domain_ref()
  550         PROVIDERS.resource_api.create_domain(test_domain3['id'], test_domain3)
  551 
  552         PROVIDERS.identity_api.add_user_to_group(user['id'], group1['id'])
  553         PROVIDERS.identity_api.add_user_to_group(user['id'], group2['id'])
  554 
  555         # Create 3 grants, one user grant, the other two as group grants
  556         PROVIDERS.assignment_api.create_grant(
  557             user_id=user['id'], domain_id=test_domain1['id'],
  558             role_id=self.role_member['id']
  559         )
  560         PROVIDERS.assignment_api.create_grant(
  561             group_id=group1['id'], domain_id=test_domain2['id'],
  562             role_id=self.role_admin['id']
  563         )
  564         PROVIDERS.assignment_api.create_grant(
  565             group_id=group2['id'], domain_id=test_domain3['id'],
  566             role_id=self.role_admin['id']
  567         )
  568         user_domains = PROVIDERS.assignment_api.list_domains_for_user(
  569             user['id']
  570         )
  571         self.assertThat(user_domains, matchers.HasLength(3))
  572 
  573     def test_list_domains_for_user_with_inherited_grants(self):
  574         """Test that inherited roles on the domain are excluded.
  575 
  576         Test Plan:
  577 
  578         - Create two domains, one user, group and role
  579         - Domain1 is given an inherited user role, Domain2 an inherited
  580           group role (for a group of which the user is a member)
  581         - When listing domains for user, neither domain should be returned
  582 
  583         """
  584         domain1 = unit.new_domain_ref()
  585         domain1 = PROVIDERS.resource_api.create_domain(domain1['id'], domain1)
  586         domain2 = unit.new_domain_ref()
  587         domain2 = PROVIDERS.resource_api.create_domain(domain2['id'], domain2)
  588         user = unit.new_user_ref(domain_id=domain1['id'])
  589         user = PROVIDERS.identity_api.create_user(user)
  590         group = unit.new_group_ref(domain_id=domain1['id'])
  591         group = PROVIDERS.identity_api.create_group(group)
  592         PROVIDERS.identity_api.add_user_to_group(user['id'], group['id'])
  593         role = unit.new_role_ref()
  594         PROVIDERS.role_api.create_role(role['id'], role)
  595 
  596         # Create a grant on each domain, one user grant, one group grant,
  597         # both inherited.
  598         PROVIDERS.assignment_api.create_grant(
  599             user_id=user['id'], domain_id=domain1['id'], role_id=role['id'],
  600             inherited_to_projects=True
  601         )
  602         PROVIDERS.assignment_api.create_grant(
  603             group_id=group['id'], domain_id=domain2['id'], role_id=role['id'],
  604             inherited_to_projects=True
  605         )
  606 
  607         user_domains = PROVIDERS.assignment_api.list_domains_for_user(
  608             user['id']
  609         )
  610         # No domains should be returned since both domains have only inherited
  611         # roles assignments.
  612         self.assertThat(user_domains, matchers.HasLength(0))
  613 
  614     def test_list_groups_for_user(self):
  615         domain = self._get_domain_fixture()
  616         test_groups = []
  617         test_users = []
  618         GROUP_COUNT = 3
  619         USER_COUNT = 2
  620 
  621         for x in range(0, USER_COUNT):
  622             new_user = unit.new_user_ref(domain_id=domain['id'])
  623             new_user = PROVIDERS.identity_api.create_user(new_user)
  624             test_users.append(new_user)
  625         positive_user = test_users[0]
  626         negative_user = test_users[1]
  627 
  628         for x in range(0, USER_COUNT):
  629             group_refs = PROVIDERS.identity_api.list_groups_for_user(
  630                 test_users[x]['id'])
  631             self.assertEqual(0, len(group_refs))
  632 
  633         for x in range(0, GROUP_COUNT):
  634             before_count = x
  635             after_count = x + 1
  636             new_group = unit.new_group_ref(domain_id=domain['id'])
  637             new_group = PROVIDERS.identity_api.create_group(new_group)
  638             test_groups.append(new_group)
  639 
  640             # add the user to the group and ensure that the
  641             # group count increases by one for each
  642             group_refs = PROVIDERS.identity_api.list_groups_for_user(
  643                 positive_user['id'])
  644             self.assertEqual(before_count, len(group_refs))
  645             PROVIDERS.identity_api.add_user_to_group(
  646                 positive_user['id'],
  647                 new_group['id'])
  648             group_refs = PROVIDERS.identity_api.list_groups_for_user(
  649                 positive_user['id'])
  650             self.assertEqual(after_count, len(group_refs))
  651 
  652             # Make sure the group count for the unrelated user did not change
  653             group_refs = PROVIDERS.identity_api.list_groups_for_user(
  654                 negative_user['id'])
  655             self.assertEqual(0, len(group_refs))
  656 
  657         # remove the user from each group and ensure that
  658         # the group count reduces by one for each
  659         for x in range(0, 3):
  660             before_count = GROUP_COUNT - x
  661             after_count = GROUP_COUNT - x - 1
  662             group_refs = PROVIDERS.identity_api.list_groups_for_user(
  663                 positive_user['id'])
  664             self.assertEqual(before_count, len(group_refs))
  665             PROVIDERS.identity_api.remove_user_from_group(
  666                 positive_user['id'],
  667                 test_groups[x]['id'])
  668             group_refs = PROVIDERS.identity_api.list_groups_for_user(
  669                 positive_user['id'])
  670             self.assertEqual(after_count, len(group_refs))
  671             # Make sure the group count for the unrelated user
  672             # did not change
  673             group_refs = PROVIDERS.identity_api.list_groups_for_user(
  674                 negative_user['id'])
  675             self.assertEqual(0, len(group_refs))
  676 
  677     def test_add_user_to_group_expiring_mapped(self):
  678         self._build_fed_resource()
  679         domain = self._get_domain_fixture()
  680         self.config_fixture.config(group='federation',
  681                                    default_authorization_ttl=5)
  682         time = datetime.datetime.utcnow()
  683         tick = datetime.timedelta(minutes=5)
  684 
  685         new_group = unit.new_group_ref(domain_id=domain['id'])
  686         new_group = PROVIDERS.identity_api.create_group(new_group)
  687 
  688         fed_dict = unit.new_federated_user_ref()
  689         fed_dict['idp_id'] = 'myidp'
  690         fed_dict['protocol_id'] = 'mapped'
  691 
  692         with freezegun.freeze_time(time - tick) as frozen_time:
  693             user = PROVIDERS.identity_api.shadow_federated_user(
  694                 **fed_dict, group_ids=[new_group['id']])
  695 
  696             PROVIDERS.identity_api.check_user_in_group(user['id'],
  697                                                        new_group['id'])
  698 
  699             # Expiration
  700             frozen_time.tick(tick)
  701             self.assertRaises(exception.NotFound,
  702                               PROVIDERS.identity_api.check_user_in_group,
  703                               user['id'],
  704                               new_group['id'])
  705 
  706             # Renewal
  707             PROVIDERS.identity_api.shadow_federated_user(
  708                 **fed_dict, group_ids=[new_group['id']])
  709             PROVIDERS.identity_api.check_user_in_group(user['id'],
  710                                                        new_group['id'])
  711 
  712     def test_add_user_to_group_expiring(self):
  713         self._build_fed_resource()
  714         domain = self._get_domain_fixture()
  715         time = datetime.datetime.utcnow()
  716         tick = datetime.timedelta(minutes=5)
  717 
  718         new_group = unit.new_group_ref(domain_id=domain['id'])
  719         new_group = PROVIDERS.identity_api.create_group(new_group)
  720 
  721         fed_dict = unit.new_federated_user_ref()
  722         fed_dict['idp_id'] = 'myidp'
  723         fed_dict['protocol_id'] = 'mapped'
  724         new_user = PROVIDERS.shadow_users_api.create_federated_user(
  725             domain['id'], fed_dict
  726         )
  727 
  728         with freezegun.freeze_time(time - tick) as frozen_time:
  729             PROVIDERS.shadow_users_api.add_user_to_group_expires(
  730                 new_user['id'], new_group['id'])
  731 
  732             self.config_fixture.config(group='federation',
  733                                        default_authorization_ttl=0)
  734             self.assertRaises(exception.NotFound,
  735                               PROVIDERS.identity_api.check_user_in_group,
  736                               new_user['id'],
  737                               new_group['id'])
  738 
  739             self.config_fixture.config(group='federation',
  740                                        default_authorization_ttl=5)
  741             PROVIDERS.identity_api.check_user_in_group(new_user['id'],
  742                                                        new_group['id'])
  743 
  744             # Expiration
  745             frozen_time.tick(tick)
  746             self.assertRaises(exception.NotFound,
  747                               PROVIDERS.identity_api.check_user_in_group,
  748                               new_user['id'],
  749                               new_group['id'])
  750 
  751             # Renewal
  752             PROVIDERS.shadow_users_api.add_user_to_group_expires(
  753                 new_user['id'], new_group['id'])
  754             PROVIDERS.identity_api.check_user_in_group(new_user['id'],
  755                                                        new_group['id'])
  756 
  757     def test_add_user_to_group_expiring_list(self):
  758         self._build_fed_resource()
  759         domain = self._get_domain_fixture()
  760         self.config_fixture.config(group='federation',
  761                                    default_authorization_ttl=5)
  762         time = datetime.datetime.utcnow()
  763         tick = datetime.timedelta(minutes=5)
  764 
  765         new_group = unit.new_group_ref(domain_id=domain['id'])
  766         new_group = PROVIDERS.identity_api.create_group(new_group)
  767         exp_new_group = unit.new_group_ref(domain_id=domain['id'])
  768         exp_new_group = PROVIDERS.identity_api.create_group(exp_new_group)
  769 
  770         fed_dict = unit.new_federated_user_ref()
  771         fed_dict['idp_id'] = 'myidp'
  772         fed_dict['protocol_id'] = 'mapped'
  773         new_user = PROVIDERS.shadow_users_api.create_federated_user(
  774             domain['id'], fed_dict
  775         )
  776 
  777         PROVIDERS.identity_api.add_user_to_group(new_user['id'],
  778                                                  new_group['id'])
  779         PROVIDERS.identity_api.check_user_in_group(new_user['id'],
  780                                                    new_group['id'])
  781 
  782         with freezegun.freeze_time(time - tick) as frozen_time:
  783             PROVIDERS.shadow_users_api.add_user_to_group_expires(
  784                 new_user['id'], exp_new_group['id'])
  785             PROVIDERS.identity_api.check_user_in_group(new_user['id'],
  786                                                        new_group['id'])
  787 
  788             groups = PROVIDERS.identity_api.list_groups_for_user(
  789                 new_user['id'])
  790             self.assertEqual(len(groups), 2)
  791             for group in groups:
  792                 if group.get('membership_expires_at'):
  793                     self.assertEqual(group['membership_expires_at'], time)
  794 
  795             frozen_time.tick(tick)
  796             groups = PROVIDERS.identity_api.list_groups_for_user(
  797                 new_user['id'])
  798             self.assertEqual(len(groups), 1)
  799 
  800     def test_storing_null_domain_id_in_project_ref(self):
  801         """Test the special storage of domain_id=None in sql resource driver.
  802 
  803         The resource driver uses a special value in place of None for domain_id
  804         in the project record. This shouldn't escape the driver. Hence we test
  805         the interface to ensure that you can store a domain_id of None, and
  806         that any special value used inside the driver does not escape through
  807         the interface.
  808 
  809         """
  810         spoiler_project = unit.new_project_ref(
  811             domain_id=CONF.identity.default_domain_id)
  812         PROVIDERS.resource_api.create_project(
  813             spoiler_project['id'], spoiler_project
  814         )
  815 
  816         # First let's create a project with a None domain_id and make sure we
  817         # can read it back.
  818         project = unit.new_project_ref(domain_id=None, is_domain=True)
  819         project = PROVIDERS.resource_api.create_project(project['id'], project)
  820         ref = PROVIDERS.resource_api.get_project(project['id'])
  821         self.assertDictEqual(project, ref)
  822 
  823         # Can we get it by name?
  824         ref = PROVIDERS.resource_api.get_project_by_name(project['name'], None)
  825         self.assertDictEqual(project, ref)
  826 
  827         # Can we filter for them - create a second domain to ensure we are
  828         # testing the receipt of more than one.
  829         project2 = unit.new_project_ref(domain_id=None, is_domain=True)
  830         project2 = PROVIDERS.resource_api.create_project(
  831             project2['id'], project2
  832         )
  833         hints = driver_hints.Hints()
  834         hints.add_filter('domain_id', None)
  835         refs = PROVIDERS.resource_api.list_projects(hints)
  836         self.assertThat(refs, matchers.HasLength(2 + self.domain_count))
  837         self.assertIn(project, refs)
  838         self.assertIn(project2, refs)
  839 
  840         # Can we update it?
  841         project['name'] = uuid.uuid4().hex
  842         PROVIDERS.resource_api.update_project(project['id'], project)
  843         ref = PROVIDERS.resource_api.get_project(project['id'])
  844         self.assertDictEqual(project, ref)
  845 
  846         # Finally, make sure we can delete it
  847         project['enabled'] = False
  848         PROVIDERS.resource_api.update_project(project['id'], project)
  849         PROVIDERS.resource_api.delete_project(project['id'])
  850         self.assertRaises(exception.ProjectNotFound,
  851                           PROVIDERS.resource_api.get_project,
  852                           project['id'])
  853 
  854     def test_hidden_project_domain_root_is_really_hidden(self):
  855         """Ensure we cannot access the hidden root of all project domains.
  856 
  857         Calling any of the driver methods should result in the same as
  858         would be returned if we passed a project that does not exist. We don't
  859         test create_project, since we do not allow a caller of our API to
  860         specify their own ID for a new entity.
  861 
  862         """
  863         def _exercise_project_api(ref_id):
  864             driver = PROVIDERS.resource_api.driver
  865             self.assertRaises(exception.ProjectNotFound,
  866                               driver.get_project,
  867                               ref_id)
  868 
  869             self.assertRaises(exception.ProjectNotFound,
  870                               driver.get_project_by_name,
  871                               resource.NULL_DOMAIN_ID,
  872                               ref_id)
  873 
  874             project_ids = [x['id'] for x in
  875                            driver.list_projects(driver_hints.Hints())]
  876             self.assertNotIn(ref_id, project_ids)
  877 
  878             projects = driver.list_projects_from_ids([ref_id])
  879             self.assertThat(projects, matchers.HasLength(0))
  880 
  881             project_ids = [x for x in
  882                            driver.list_project_ids_from_domain_ids([ref_id])]
  883             self.assertNotIn(ref_id, project_ids)
  884 
  885             self.assertRaises(exception.DomainNotFound,
  886                               driver.list_projects_in_domain,
  887                               ref_id)
  888 
  889             project_ids = [
  890                 x['id'] for x in
  891                 driver.list_projects_acting_as_domain(driver_hints.Hints())]
  892             self.assertNotIn(ref_id, project_ids)
  893 
  894             projects = driver.list_projects_in_subtree(ref_id)
  895             self.assertThat(projects, matchers.HasLength(0))
  896 
  897             self.assertRaises(exception.ProjectNotFound,
  898                               driver.list_project_parents,
  899                               ref_id)
  900 
  901             # A non-existing project just returns True from the driver
  902             self.assertTrue(driver.is_leaf_project(ref_id))
  903 
  904             self.assertRaises(exception.ProjectNotFound,
  905                               driver.update_project,
  906                               ref_id,
  907                               {})
  908 
  909             self.assertRaises(exception.ProjectNotFound,
  910                               driver.delete_project,
  911                               ref_id)
  912 
  913             # Deleting list of projects that includes a non-existing project
  914             # should be silent. The root domain <<keystone.domain.root>> can't
  915             # be deleted.
  916             if ref_id != resource.NULL_DOMAIN_ID:
  917                 driver.delete_projects_from_ids([ref_id])
  918 
  919         _exercise_project_api(uuid.uuid4().hex)
  920         _exercise_project_api(resource.NULL_DOMAIN_ID)
  921 
  922     def test_list_users_call_count(self):
  923         """There should not be O(N) queries."""
  924         # create 10 users. 10 is just a random number
  925         for i in range(10):
  926             user = unit.new_user_ref(domain_id=CONF.identity.default_domain_id)
  927             PROVIDERS.identity_api.create_user(user)
  928 
  929         # sqlalchemy emits various events and allows to listen to them. Here
  930         # bound method `query_counter` will be called each time when a query
  931         # is compiled
  932         class CallCounter(object):
  933             def __init__(self):
  934                 self.calls = 0
  935 
  936             def reset(self):
  937                 self.calls = 0
  938 
  939             def query_counter(self, query):
  940                 self.calls += 1
  941 
  942         counter = CallCounter()
  943         sqlalchemy.event.listen(sqlalchemy.orm.query.Query, 'before_compile',
  944                                 counter.query_counter)
  945 
  946         first_call_users = PROVIDERS.identity_api.list_users()
  947         first_call_counter = counter.calls
  948         # add 10 more users
  949         for i in range(10):
  950             user = unit.new_user_ref(domain_id=CONF.identity.default_domain_id)
  951             PROVIDERS.identity_api.create_user(user)
  952         counter.reset()
  953         second_call_users = PROVIDERS.identity_api.list_users()
  954         # ensure that the number of calls does not depend on the number of
  955         # users fetched.
  956         self.assertNotEqual(len(first_call_users), len(second_call_users))
  957         self.assertEqual(first_call_counter, counter.calls)
  958         self.assertEqual(3, counter.calls)
  959 
  960     def test_check_project_depth(self):
  961         # Create a 3 level project tree:
  962         #
  963         # default_domain
  964         #       |
  965         #   project_1
  966         #       |
  967         #   project_2
  968         project_1 = unit.new_project_ref(
  969             domain_id=CONF.identity.default_domain_id)
  970         PROVIDERS.resource_api.create_project(project_1['id'], project_1)
  971         project_2 = unit.new_project_ref(
  972             domain_id=CONF.identity.default_domain_id,
  973             parent_id=project_1['id'])
  974         PROVIDERS.resource_api.create_project(project_2['id'], project_2)
  975 
  976         # if max_depth is None or >= current project depth, return nothing.
  977         resp = PROVIDERS.resource_api.check_project_depth(max_depth=None)
  978         self.assertIsNone(resp)
  979         resp = PROVIDERS.resource_api.check_project_depth(max_depth=3)
  980         self.assertIsNone(resp)
  981         resp = PROVIDERS.resource_api.check_project_depth(max_depth=4)
  982         self.assertIsNone(resp)
  983         # if max_depth < current project depth, raise LimitTreeExceedError
  984         self.assertRaises(exception.LimitTreeExceedError,
  985                           PROVIDERS.resource_api.check_project_depth,
  986                           2)
  987 
  988     def test_update_user_with_stale_data_forces_retry(self):
  989         # Capture log output so we know oslo.db attempted a retry
  990         log_fixture = self.useFixture(fixtures.FakeLogger(level=log.DEBUG))
  991 
  992         # Create a new user
  993         user_dict = unit.new_user_ref(
  994             domain_id=CONF.identity.default_domain_id)
  995         new_user_dict = PROVIDERS.identity_api.create_user(user_dict)
  996 
  997         side_effects = [
  998             # Raise a StaleDataError simulating that another client has
  999             # updated the user's password while this client's request was
 1000             # being processed
 1001             sqlalchemy.orm.exc.StaleDataError,
 1002             # The oslo.db library will retry the request, so the second
 1003             # time this method is called let's return a valid session
 1004             # object
 1005             sql.session_for_write()
 1006         ]
 1007         with mock.patch('keystone.common.sql.session_for_write') as m:
 1008             m.side_effect = side_effects
 1009 
 1010             # Update a user's attribute, the first attempt will fail but
 1011             # oslo.db will handle the exception and retry, the second attempt
 1012             # will succeed
 1013             new_user_dict['email'] = uuid.uuid4().hex
 1014             PROVIDERS.identity_api.update_user(
 1015                 new_user_dict['id'], new_user_dict)
 1016 
 1017         # Make sure oslo.db retried the update by checking the log output
 1018         expected_log_message = (
 1019             'Performing DB retry for function keystone.identity.backends.'
 1020             'sql.Identity.update_user'
 1021         )
 1022         self.assertIn(expected_log_message, log_fixture.output)
 1023 
 1024 
 1025 class SqlTrust(SqlTests, trust_tests.TrustTests):
 1026 
 1027     def test_trust_expires_at_int_matches_expires_at(self):
 1028         with sql.session_for_write() as session:
 1029             new_id = uuid.uuid4().hex
 1030             self.create_sample_trust(new_id)
 1031             trust_ref = session.query(trust_sql.TrustModel).get(new_id)
 1032             self.assertIsNotNone(trust_ref._expires_at)
 1033             self.assertEqual(trust_ref._expires_at, trust_ref.expires_at_int)
 1034             self.assertEqual(trust_ref.expires_at, trust_ref.expires_at_int)
 1035 
 1036 
 1037 class SqlCatalog(SqlTests, catalog_tests.CatalogTests):
 1038 
 1039     _legacy_endpoint_id_in_endpoint = True
 1040     _enabled_default_to_true_when_creating_endpoint = True
 1041 
 1042     def test_get_v3_catalog_project_non_exist(self):
 1043         service = unit.new_service_ref()
 1044         PROVIDERS.catalog_api.create_service(service['id'], service)
 1045 
 1046         malformed_url = "http://192.168.1.104:8774/v2/$(project)s"
 1047         endpoint = unit.new_endpoint_ref(service_id=service['id'],
 1048                                          url=malformed_url,
 1049                                          region_id=None)
 1050         PROVIDERS.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
 1051         self.assertRaises(exception.ProjectNotFound,
 1052                           PROVIDERS.catalog_api.get_v3_catalog,
 1053                           'fake-user',
 1054                           'fake-project')
 1055 
 1056     def test_get_v3_catalog_with_empty_public_url(self):
 1057         service = unit.new_service_ref()
 1058         PROVIDERS.catalog_api.create_service(service['id'], service)
 1059 
 1060         endpoint = unit.new_endpoint_ref(url='', service_id=service['id'],
 1061                                          region_id=None)
 1062         PROVIDERS.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
 1063 
 1064         catalog = PROVIDERS.catalog_api.get_v3_catalog(self.user_foo['id'],
 1065                                                        self.project_bar['id'])
 1066         catalog_endpoint = catalog[0]
 1067         self.assertEqual(service['name'], catalog_endpoint['name'])
 1068         self.assertEqual(service['id'], catalog_endpoint['id'])
 1069         self.assertEqual([], catalog_endpoint['endpoints'])
 1070 
 1071     def test_create_endpoint_region_returns_not_found(self):
 1072         service = unit.new_service_ref()
 1073         PROVIDERS.catalog_api.create_service(service['id'], service)
 1074 
 1075         endpoint = unit.new_endpoint_ref(region_id=uuid.uuid4().hex,
 1076                                          service_id=service['id'])
 1077 
 1078         self.assertRaises(exception.ValidationError,
 1079                           PROVIDERS.catalog_api.create_endpoint,
 1080                           endpoint['id'],
 1081                           endpoint.copy())
 1082 
 1083     def test_create_region_invalid_id(self):
 1084         region = unit.new_region_ref(id='0' * 256)
 1085 
 1086         self.assertRaises(exception.StringLengthExceeded,
 1087                           PROVIDERS.catalog_api.create_region,
 1088                           region)
 1089 
 1090     def test_create_region_invalid_parent_id(self):
 1091         region = unit.new_region_ref(parent_region_id='0' * 256)
 1092 
 1093         self.assertRaises(exception.RegionNotFound,
 1094                           PROVIDERS.catalog_api.create_region,
 1095                           region)
 1096 
 1097     def test_delete_region_with_endpoint(self):
 1098         # create a region
 1099         region = unit.new_region_ref()
 1100         PROVIDERS.catalog_api.create_region(region)
 1101 
 1102         # create a child region
 1103         child_region = unit.new_region_ref(parent_region_id=region['id'])
 1104         PROVIDERS.catalog_api.create_region(child_region)
 1105         # create a service
 1106         service = unit.new_service_ref()
 1107         PROVIDERS.catalog_api.create_service(service['id'], service)
 1108 
 1109         # create an endpoint attached to the service and child region
 1110         child_endpoint = unit.new_endpoint_ref(region_id=child_region['id'],
 1111                                                service_id=service['id'])
 1112 
 1113         PROVIDERS.catalog_api.create_endpoint(
 1114             child_endpoint['id'], child_endpoint
 1115         )
 1116         self.assertRaises(exception.RegionDeletionError,
 1117                           PROVIDERS.catalog_api.delete_region,
 1118                           child_region['id'])
 1119 
 1120         # create an endpoint attached to the service and parent region
 1121         endpoint = unit.new_endpoint_ref(region_id=region['id'],
 1122                                          service_id=service['id'])
 1123 
 1124         PROVIDERS.catalog_api.create_endpoint(endpoint['id'], endpoint)
 1125         self.assertRaises(exception.RegionDeletionError,
 1126                           PROVIDERS.catalog_api.delete_region,
 1127                           region['id'])
 1128 
 1129     def test_v3_catalog_domain_scoped_token(self):
 1130         # test the case that project_id is None.
 1131         srv_1 = unit.new_service_ref()
 1132         PROVIDERS.catalog_api.create_service(srv_1['id'], srv_1)
 1133         endpoint_1 = unit.new_endpoint_ref(service_id=srv_1['id'],
 1134                                            region_id=None)
 1135         PROVIDERS.catalog_api.create_endpoint(endpoint_1['id'], endpoint_1)
 1136 
 1137         srv_2 = unit.new_service_ref()
 1138         PROVIDERS.catalog_api.create_service(srv_2['id'], srv_2)
 1139         endpoint_2 = unit.new_endpoint_ref(service_id=srv_2['id'],
 1140                                            region_id=None)
 1141         PROVIDERS.catalog_api.create_endpoint(endpoint_2['id'], endpoint_2)
 1142 
 1143         self.config_fixture.config(group='endpoint_filter',
 1144                                    return_all_endpoints_if_no_filter=True)
 1145         catalog_ref = PROVIDERS.catalog_api.get_v3_catalog(
 1146             uuid.uuid4().hex, None
 1147         )
 1148         self.assertThat(catalog_ref, matchers.HasLength(2))
 1149         self.config_fixture.config(group='endpoint_filter',
 1150                                    return_all_endpoints_if_no_filter=False)
 1151         catalog_ref = PROVIDERS.catalog_api.get_v3_catalog(
 1152             uuid.uuid4().hex, None
 1153         )
 1154         self.assertThat(catalog_ref, matchers.HasLength(0))
 1155 
 1156     def test_v3_catalog_endpoint_filter_enabled(self):
 1157         srv_1 = unit.new_service_ref()
 1158         PROVIDERS.catalog_api.create_service(srv_1['id'], srv_1)
 1159         endpoint_1 = unit.new_endpoint_ref(service_id=srv_1['id'],
 1160                                            region_id=None)
 1161         PROVIDERS.catalog_api.create_endpoint(endpoint_1['id'], endpoint_1)
 1162         endpoint_2 = unit.new_endpoint_ref(service_id=srv_1['id'],
 1163                                            region_id=None)
 1164         PROVIDERS.catalog_api.create_endpoint(endpoint_2['id'], endpoint_2)
 1165         # create endpoint-project association.
 1166         PROVIDERS.catalog_api.add_endpoint_to_project(
 1167             endpoint_1['id'],
 1168             self.project_bar['id'])
 1169 
 1170         catalog_ref = PROVIDERS.catalog_api.get_v3_catalog(
 1171             uuid.uuid4().hex, self.project_bar['id']
 1172         )
 1173         self.assertThat(catalog_ref, matchers.HasLength(1))
 1174         self.assertThat(catalog_ref[0]['endpoints'], matchers.HasLength(1))
 1175         # the endpoint is that defined in the endpoint-project association.
 1176         self.assertEqual(endpoint_1['id'],
 1177                          catalog_ref[0]['endpoints'][0]['id'])
 1178 
 1179     def test_v3_catalog_endpoint_filter_disabled(self):
 1180         # there is no endpoint-project association defined.
 1181         self.config_fixture.config(group='endpoint_filter',
 1182                                    return_all_endpoints_if_no_filter=True)
 1183         srv_1 = unit.new_service_ref()
 1184         PROVIDERS.catalog_api.create_service(srv_1['id'], srv_1)
 1185         endpoint_1 = unit.new_endpoint_ref(service_id=srv_1['id'],
 1186                                            region_id=None)
 1187         PROVIDERS.catalog_api.create_endpoint(endpoint_1['id'], endpoint_1)
 1188 
 1189         srv_2 = unit.new_service_ref()
 1190         PROVIDERS.catalog_api.create_service(srv_2['id'], srv_2)
 1191 
 1192         catalog_ref = PROVIDERS.catalog_api.get_v3_catalog(
 1193             uuid.uuid4().hex, self.project_bar['id']
 1194         )
 1195         self.assertThat(catalog_ref, matchers.HasLength(2))
 1196         srv_id_list = [catalog_ref[0]['id'], catalog_ref[1]['id']]
 1197         self.assertItemsEqual([srv_1['id'], srv_2['id']], srv_id_list)
 1198 
 1199 
 1200 class SqlPolicy(SqlTests, policy_tests.PolicyTests):
 1201     pass
 1202 
 1203 
 1204 class SqlInheritance(SqlTests, assignment_tests.InheritanceTests):
 1205     pass
 1206 
 1207 
 1208 class SqlImpliedRoles(SqlTests, assignment_tests.ImpliedRoleTests):
 1209     pass
 1210 
 1211 
 1212 class SqlFilterTests(SqlTests, identity_tests.FilterTests):
 1213 
 1214     def clean_up_entities(self):
 1215         """Clean up entity test data from Filter Test Cases."""
 1216         for entity in ['user', 'group', 'project']:
 1217             self._delete_test_data(entity, self.entity_list[entity])
 1218             self._delete_test_data(entity, self.domain1_entity_list[entity])
 1219         del self.entity_list
 1220         del self.domain1_entity_list
 1221         self.domain1['enabled'] = False
 1222         PROVIDERS.resource_api.update_domain(self.domain1['id'], self.domain1)
 1223         PROVIDERS.resource_api.delete_domain(self.domain1['id'])
 1224         del self.domain1
 1225 
 1226     def test_list_entities_filtered_by_domain(self):
 1227         # NOTE(henry-nash): This method is here rather than in
 1228         # unit.identity.test_backends since any domain filtering with LDAP is
 1229         # handled by the manager layer (and is already tested elsewhere) not at
 1230         # the driver level.
 1231         self.addCleanup(self.clean_up_entities)
 1232         self.domain1 = unit.new_domain_ref()
 1233         PROVIDERS.resource_api.create_domain(self.domain1['id'], self.domain1)
 1234 
 1235         self.entity_list = {}
 1236         self.domain1_entity_list = {}
 1237         for entity in ['user', 'group', 'project']:
 1238             # Create 5 entities, 3 of which are in domain1
 1239             DOMAIN1_ENTITIES = 3
 1240             self.entity_list[entity] = self._create_test_data(entity, 2)
 1241             self.domain1_entity_list[entity] = self._create_test_data(
 1242                 entity, DOMAIN1_ENTITIES, self.domain1['id'])
 1243 
 1244             # Should get back the DOMAIN1_ENTITIES in domain1
 1245             hints = driver_hints.Hints()
 1246             hints.add_filter('domain_id', self.domain1['id'])
 1247             entities = self._list_entities(entity)(hints=hints)
 1248             self.assertEqual(DOMAIN1_ENTITIES, len(entities))
 1249             self._match_with_list(entities, self.domain1_entity_list[entity])
 1250             # Check the driver has removed the filter from the list hints
 1251             self.assertFalse(hints.get_exact_filter_by_name('domain_id'))
 1252 
 1253     def test_filter_sql_injection_attack(self):
 1254         """Test against sql injection attack on filters.
 1255 
 1256         Test Plan:
 1257         - Attempt to get all entities back by passing a two-term attribute
 1258         - Attempt to piggyback filter to damage DB (e.g. drop table)
 1259 
 1260         """
 1261         # Check we have some users
 1262         users = PROVIDERS.identity_api.list_users()
 1263         self.assertGreater(len(users), 0)
 1264 
 1265         hints = driver_hints.Hints()
 1266         hints.add_filter('name', "anything' or 'x'='x")
 1267         users = PROVIDERS.identity_api.list_users(hints=hints)
 1268         self.assertEqual(0, len(users))
 1269 
 1270         # See if we can add a SQL command...use the group table instead of the
 1271         # user table since 'user' is reserved word for SQLAlchemy.
 1272         group = unit.new_group_ref(domain_id=CONF.identity.default_domain_id)
 1273         group = PROVIDERS.identity_api.create_group(group)
 1274 
 1275         hints = driver_hints.Hints()
 1276         hints.add_filter('name', "x'; drop table group")
 1277         groups = PROVIDERS.identity_api.list_groups(hints=hints)
 1278         self.assertEqual(0, len(groups))
 1279 
 1280         groups = PROVIDERS.identity_api.list_groups()
 1281         self.assertGreater(len(groups), 0)
 1282 
 1283 
 1284 class SqlLimitTests(SqlTests, identity_tests.LimitTests):
 1285     def setUp(self):
 1286         super(SqlLimitTests, self).setUp()
 1287         identity_tests.LimitTests.setUp(self)
 1288 
 1289 
 1290 class FakeTable(sql.ModelBase):
 1291     __tablename__ = 'test_table'
 1292     col = sql.Column(sql.String(32), primary_key=True)
 1293 
 1294     @sql.handle_conflicts('keystone')
 1295     def insert(self):
 1296         raise db_exception.DBDuplicateEntry
 1297 
 1298     @sql.handle_conflicts('keystone')
 1299     def update(self):
 1300         raise db_exception.DBError(
 1301             inner_exception=exc.IntegrityError('a', 'a', 'a'))
 1302 
 1303     @sql.handle_conflicts('keystone')
 1304     def lookup(self):
 1305         raise KeyError
 1306 
 1307 
 1308 class SqlDecorators(unit.TestCase):
 1309 
 1310     def test_initialization_fail(self):
 1311         self.assertRaises(exception.StringLengthExceeded,
 1312                           FakeTable, col='a' * 64)
 1313 
 1314     def test_initialization(self):
 1315         tt = FakeTable(col='a')
 1316         self.assertEqual('a', tt.col)
 1317 
 1318     def test_conflict_happend(self):
 1319         self.assertRaises(exception.Conflict, FakeTable().insert)
 1320         self.assertRaises(exception.UnexpectedError, FakeTable().update)
 1321 
 1322     def test_not_conflict_error(self):
 1323         self.assertRaises(KeyError, FakeTable().lookup)
 1324 
 1325 
 1326 class SqlModuleInitialization(unit.TestCase):
 1327 
 1328     @mock.patch.object(sql.core, 'CONF')
 1329     @mock.patch.object(options, 'set_defaults')
 1330     def test_initialize_module(self, set_defaults, CONF):
 1331         sql.initialize()
 1332         set_defaults.assert_called_with(CONF,
 1333                                         connection='sqlite:///keystone.db')
 1334 
 1335 
 1336 class SqlCredential(SqlTests):
 1337 
 1338     def _create_credential_with_user_id(self, user_id=uuid.uuid4().hex):
 1339         credential = unit.new_credential_ref(user_id=user_id,
 1340                                              extra=uuid.uuid4().hex,
 1341                                              type=uuid.uuid4().hex)
 1342         PROVIDERS.credential_api.create_credential(
 1343             credential['id'], credential
 1344         )
 1345         return credential
 1346 
 1347     def _validateCredentialList(self, retrieved_credentials,
 1348                                 expected_credentials):
 1349         self.assertEqual(len(expected_credentials), len(retrieved_credentials))
 1350         retrived_ids = [c['id'] for c in retrieved_credentials]
 1351         for cred in expected_credentials:
 1352             self.assertIn(cred['id'], retrived_ids)
 1353 
 1354     def setUp(self):
 1355         self.useFixture(database.Database())
 1356         super(SqlCredential, self).setUp()
 1357         self.useFixture(
 1358             ksfixtures.KeyRepository(
 1359                 self.config_fixture,
 1360                 'credential',
 1361                 credential_provider.MAX_ACTIVE_KEYS
 1362             )
 1363         )
 1364 
 1365         self.credentials = []
 1366         for _ in range(3):
 1367             self.credentials.append(
 1368                 self._create_credential_with_user_id())
 1369         self.user_credentials = []
 1370         for _ in range(3):
 1371             cred = self._create_credential_with_user_id(self.user_foo['id'])
 1372             self.user_credentials.append(cred)
 1373             self.credentials.append(cred)
 1374 
 1375     def test_list_credentials(self):
 1376         credentials = PROVIDERS.credential_api.list_credentials()
 1377         self._validateCredentialList(credentials, self.credentials)
 1378         # test filtering using hints
 1379         hints = driver_hints.Hints()
 1380         hints.add_filter('user_id', self.user_foo['id'])
 1381         credentials = PROVIDERS.credential_api.list_credentials(hints)
 1382         self._validateCredentialList(credentials, self.user_credentials)
 1383 
 1384     def test_list_credentials_for_user(self):
 1385         credentials = PROVIDERS.credential_api.list_credentials_for_user(
 1386             self.user_foo['id'])
 1387         self._validateCredentialList(credentials, self.user_credentials)
 1388 
 1389     def test_list_credentials_for_user_and_type(self):
 1390         cred = self.user_credentials[0]
 1391         credentials = PROVIDERS.credential_api.list_credentials_for_user(
 1392             self.user_foo['id'], type=cred['type'])
 1393         self._validateCredentialList(credentials, [cred])
 1394 
 1395     def test_create_credential_is_encrypted_when_stored(self):
 1396         credential = unit.new_credential_ref(user_id=uuid.uuid4().hex)
 1397         credential_id = credential['id']
 1398         returned_credential = PROVIDERS.credential_api.create_credential(
 1399             credential_id,
 1400             credential
 1401         )
 1402 
 1403         # Make sure the `blob` is *not* encrypted when returned from the
 1404         # credential API.
 1405         self.assertEqual(returned_credential['blob'], credential['blob'])
 1406 
 1407         credential_from_backend = (
 1408             PROVIDERS.credential_api.driver.get_credential(credential_id)
 1409         )
 1410 
 1411         # Pull the credential directly from the backend, the `blob` should be
 1412         # encrypted.
 1413         self.assertNotEqual(
 1414             credential_from_backend['encrypted_blob'],
 1415             credential['blob']
 1416         )
 1417 
 1418     def test_list_credentials_is_decrypted(self):
 1419         credential = unit.new_credential_ref(user_id=uuid.uuid4().hex)
 1420         credential_id = credential['id']
 1421 
 1422         created_credential = PROVIDERS.credential_api.create_credential(
 1423             credential_id,
 1424             credential
 1425         )
 1426 
 1427         # Pull the credential directly from the backend, the `blob` should be
 1428         # encrypted.
 1429         credential_from_backend = (
 1430             PROVIDERS.credential_api.driver.get_credential(credential_id)
 1431         )
 1432         self.assertNotEqual(
 1433             credential_from_backend['encrypted_blob'],
 1434             credential['blob']
 1435         )
 1436 
 1437         # Make sure the `blob` values listed from the API are not encrypted.
 1438         listed_credentials = PROVIDERS.credential_api.list_credentials()
 1439         self.assertIn(created_credential, listed_credentials)
 1440 
 1441 
 1442 class SqlRegisteredLimit(SqlTests, limit_tests.RegisteredLimitTests):
 1443 
 1444     def setUp(self):
 1445         super(SqlRegisteredLimit, self).setUp()
 1446 
 1447         fixtures_to_cleanup = []
 1448         for service in default_fixtures.SERVICES:
 1449             service_id = service['id']
 1450             rv = PROVIDERS.catalog_api.create_service(service_id, service)
 1451             attrname = service['extra']['name']
 1452             setattr(self, attrname, rv)
 1453             fixtures_to_cleanup.append(attrname)
 1454         for region in default_fixtures.REGIONS:
 1455             rv = PROVIDERS.catalog_api.create_region(region)
 1456             attrname = region['id']
 1457             setattr(self, attrname, rv)
 1458             fixtures_to_cleanup.append(attrname)
 1459         self.addCleanup(self.cleanup_instance(*fixtures_to_cleanup))
 1460 
 1461 
 1462 class SqlLimit(SqlTests, limit_tests.LimitTests):
 1463 
 1464     def setUp(self):
 1465         super(SqlLimit, self).setUp()
 1466 
 1467         fixtures_to_cleanup = []
 1468         for service in default_fixtures.SERVICES:
 1469             service_id = service['id']
 1470             rv = PROVIDERS.catalog_api.create_service(service_id, service)
 1471             attrname = service['extra']['name']
 1472             setattr(self, attrname, rv)
 1473             fixtures_to_cleanup.append(attrname)
 1474         for region in default_fixtures.REGIONS:
 1475             rv = PROVIDERS.catalog_api.create_region(region)
 1476             attrname = region['id']
 1477             setattr(self, attrname, rv)
 1478             fixtures_to_cleanup.append(attrname)
 1479         self.addCleanup(self.cleanup_instance(*fixtures_to_cleanup))
 1480 
 1481         registered_limit_1 = unit.new_registered_limit_ref(
 1482             service_id=self.service_one['id'],
 1483             region_id=self.region_one['id'],
 1484             resource_name='volume', default_limit=10, id=uuid.uuid4().hex)
 1485         registered_limit_2 = unit.new_registered_limit_ref(
 1486             service_id=self.service_one['id'],
 1487             region_id=self.region_two['id'],
 1488             resource_name='snapshot', default_limit=10, id=uuid.uuid4().hex)
 1489         registered_limit_3 = unit.new_registered_limit_ref(
 1490             service_id=self.service_one['id'],
 1491             region_id=self.region_two['id'],
 1492             resource_name='backup', default_limit=10, id=uuid.uuid4().hex)
 1493         PROVIDERS.unified_limit_api.create_registered_limits(
 1494             [registered_limit_1, registered_limit_2, registered_limit_3])