"Fossies" - the Fresh Open Source Software Archive

Member "keystone-16.0.2/keystone/tests/unit/test_backend_sql.py" (7 Jun 2021, 56810 Bytes) of package /linux/misc/openstack/keystone-16.0.2.tar.gz:


As a special service, "Fossies" has tried to format the requested source page as HTML using (guessed) Python source code syntax highlighting (style: standard) with prefixed line numbers. Alternatively, you can view or download the uninterpreted source code file here. See also the latest Fossies "Diffs" side-by-side code changes report for "test_backend_sql.py": 16.0.1_vs_16.0.2.

    1 # Copyright 2012 OpenStack Foundation
    2 #
    3 # Licensed under the Apache License, Version 2.0 (the "License"); you may
    4 # not use this file except in compliance with the License. You may obtain
    5 # a copy of the License at
    6 #
    7 #      http://www.apache.org/licenses/LICENSE-2.0
    8 #
    9 # Unless required by applicable law or agreed to in writing, software
   10 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
   11 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
   12 # License for the specific language governing permissions and limitations
   13 # under the License.
   14 
   15 import datetime
   16 import uuid
   17 
   18 import fixtures
   19 import mock
   20 from oslo_db import exception as db_exception
   21 from oslo_db import options
   22 from oslo_log import log
   23 from six.moves import range
   24 import sqlalchemy
   25 from sqlalchemy import exc
   26 from testtools import matchers
   27 
   28 from keystone.common import driver_hints
   29 from keystone.common import provider_api
   30 from keystone.common import sql
   31 from keystone.common.sql import core
   32 import keystone.conf
   33 from keystone.credential.providers import fernet as credential_provider
   34 from keystone import exception
   35 from keystone.identity.backends import sql_model as identity_sql
   36 from keystone.resource.backends import base as resource
   37 from keystone.tests import unit
   38 from keystone.tests.unit.assignment import test_backends as assignment_tests
   39 from keystone.tests.unit.catalog import test_backends as catalog_tests
   40 from keystone.tests.unit import default_fixtures
   41 from keystone.tests.unit.identity import test_backends as identity_tests
   42 from keystone.tests.unit import ksfixtures
   43 from keystone.tests.unit.ksfixtures import database
   44 from keystone.tests.unit.limit import test_backends as limit_tests
   45 from keystone.tests.unit.policy import test_backends as policy_tests
   46 from keystone.tests.unit.resource import test_backends as resource_tests
   47 from keystone.tests.unit.trust import test_backends as trust_tests
   48 from keystone.trust.backends import sql as trust_sql
   49 
   50 
   51 CONF = keystone.conf.CONF
   52 PROVIDERS = provider_api.ProviderAPIs
   53 
   54 
   55 class SqlTests(unit.SQLDriverOverrides, unit.TestCase):
   56 
   57     def setUp(self):
   58         super(SqlTests, self).setUp()
   59         self.useFixture(database.Database())
   60         self.load_backends()
   61 
   62         # populate the engine with tables & fixtures
   63         self.load_fixtures(default_fixtures)
   64         # defaulted by the data load
   65         self.user_foo['enabled'] = True
   66 
   67     def config_files(self):
   68         config_files = super(SqlTests, self).config_files()
   69         config_files.append(unit.dirs.tests_conf('backend_sql.conf'))
   70         return config_files
   71 
   72 
   73 class DataTypeRoundTrips(SqlTests):
   74     def test_json_blob_roundtrip(self):
   75         """Test round-trip of a JSON data structure with JsonBlob."""
   76         with sql.session_for_read() as session:
   77             val = session.scalar(
   78                 sqlalchemy.select(
   79                     [sqlalchemy.literal({"key": "value"}, type_=core.JsonBlob)]
   80                 )
   81             )
   82 
   83         self.assertEqual({"key": "value"}, val)
   84 
   85     def test_json_blob_sql_null(self):
   86         """Test that JsonBlob can accommodate a SQL NULL value in a result set.
   87 
   88         SQL NULL may be handled by JsonBlob in the case where a table is
   89         storing NULL in a JsonBlob column, as several models use this type
   90         in a column that is nullable.   It also comes back when the column
   91         is left NULL from being in an OUTER JOIN.  In Python, this means
   92         the None constant is handled by the datatype.
   93 
   94         """
   95         with sql.session_for_read() as session:
   96             val = session.scalar(
   97                 sqlalchemy.select(
   98                     [sqlalchemy.cast(sqlalchemy.null(), type_=core.JsonBlob)]
   99                 )
  100             )
  101 
  102         self.assertIsNone(val)
  103 
  104     def test_json_blob_python_none(self):
  105         """Test that JsonBlob round-trips a Python None.
  106 
  107         This is where JSON datatypes get a little nutty, in that JSON has
  108         a 'null' keyword, and JsonBlob right now will persist Python None
  109         as the json string 'null', not SQL NULL.
  110 
  111         """
  112         with sql.session_for_read() as session:
  113             val = session.scalar(
  114                 sqlalchemy.select(
  115                     [sqlalchemy.literal(None, type_=core.JsonBlob)]
  116                 )
  117             )
  118 
  119         self.assertIsNone(val)
  120 
  121     def test_json_blob_python_none_renders(self):
  122         """Test that JsonBlob actually renders JSON 'null' for Python None."""
  123         with sql.session_for_read() as session:
  124             val = session.scalar(
  125                 sqlalchemy.select(
  126                     [
  127                         sqlalchemy.cast(
  128                             sqlalchemy.literal(None, type_=core.JsonBlob),
  129                             sqlalchemy.String,
  130                         )
  131                     ]
  132                 )
  133             )
  134 
  135         self.assertEqual("null", val)
  136 
  137     def test_datetimeint_roundtrip(self):
  138         """Test round-trip of a Python datetime with DateTimeInt."""
  139         with sql.session_for_read() as session:
  140             datetime_value = datetime.datetime(2019, 5, 15, 10, 17, 55)
  141             val = session.scalar(
  142                 sqlalchemy.select(
  143                     [
  144                         sqlalchemy.literal(
  145                             datetime_value, type_=core.DateTimeInt
  146                         ),
  147                     ]
  148                 )
  149             )
  150 
  151         self.assertEqual(datetime_value, val)
  152 
  153     def test_datetimeint_persistence(self):
  154         """Test integer persistence with DateTimeInt."""
  155         with sql.session_for_read() as session:
  156             datetime_value = datetime.datetime(2019, 5, 15, 10, 17, 55)
  157             val = session.scalar(
  158                 sqlalchemy.select(
  159                     [
  160                         sqlalchemy.cast(
  161                             sqlalchemy.literal(
  162                                 datetime_value, type_=core.DateTimeInt
  163                             ),
  164                             sqlalchemy.Integer
  165                         )
  166                     ]
  167                 )
  168             )
  169 
  170         self.assertEqual(1557915475000000, val)
  171 
  172     def test_datetimeint_python_none(self):
  173         """Test round-trip of a Python None with DateTimeInt."""
  174         with sql.session_for_read() as session:
  175             val = session.scalar(
  176                 sqlalchemy.select(
  177                     [
  178                         sqlalchemy.literal(None, type_=core.DateTimeInt),
  179                     ]
  180                 )
  181             )
  182 
  183         self.assertIsNone(val)
  184 
  185 
  186 class SqlModels(SqlTests):
  187 
  188     def load_table(self, name):
  189         table = sqlalchemy.Table(name,
  190                                  sql.ModelBase.metadata,
  191                                  autoload=True)
  192         return table
  193 
  194     def assertExpectedSchema(self, table, expected_schema):
  195         """Assert that a table's schema is what we expect.
  196 
  197         :param string table: the name of the table to inspect
  198         :param tuple expected_schema: a tuple of tuples containing the
  199             expected schema
  200         :raises AssertionError: when the database schema doesn't match the
  201             expected schema
  202 
  203         The expected_schema format is simply::
  204 
  205             (
  206                 ('column name', sql type, qualifying detail),
  207                 ...
  208             )
  209 
  210         The qualifying detail varies based on the type of the column::
  211 
  212           - sql.Boolean columns must indicate the column's default value or
  213             None if there is no default
  214           - Columns with a length, like sql.String, must indicate the
  215             column's length
  216           - All other column types should use None
  217 
  218         Example::
  219 
  220             cols = (('id', sql.String, 64),
  221                     ('enabled', sql.Boolean, True),
  222                     ('extra', sql.JsonBlob, None))
  223             self.assertExpectedSchema('table_name', cols)
  224 
  225         """
  226         table = self.load_table(table)
  227 
  228         actual_schema = []
  229         for column in table.c:
  230             if isinstance(column.type, sql.Boolean):
  231                 default = None
  232                 if column.default:
  233                     default = column.default.arg
  234                 actual_schema.append((column.name, type(column.type), default))
  235             elif (hasattr(column.type, 'length') and
  236                     not isinstance(column.type, sql.Enum)):
  237                 # NOTE(dstanek): Even though sql.Enum columns have a length
  238                 # set we don't want to catch them here. Maybe in the future
  239                 # we'll check to see that they contain a list of the correct
  240                 # possible values.
  241                 actual_schema.append((column.name,
  242                                       type(column.type),
  243                                       column.type.length))
  244             else:
  245                 actual_schema.append((column.name, type(column.type), None))
  246 
  247         self.assertItemsEqual(expected_schema, actual_schema)
  248 
  249     def test_user_model(self):
  250         cols = (('id', sql.String, 64),
  251                 ('domain_id', sql.String, 64),
  252                 ('default_project_id', sql.String, 64),
  253                 ('enabled', sql.Boolean, None),
  254                 ('extra', sql.JsonBlob, None),
  255                 ('created_at', sql.DateTime, None),
  256                 ('last_active_at', sqlalchemy.Date, None))
  257         self.assertExpectedSchema('user', cols)
  258 
  259     def test_local_user_model(self):
  260         cols = (('id', sql.Integer, None),
  261                 ('user_id', sql.String, 64),
  262                 ('name', sql.String, 255),
  263                 ('domain_id', sql.String, 64),
  264                 ('failed_auth_count', sql.Integer, None),
  265                 ('failed_auth_at', sql.DateTime, None))
  266         self.assertExpectedSchema('local_user', cols)
  267 
  268     def test_password_model(self):
  269         cols = (('id', sql.Integer, None),
  270                 ('local_user_id', sql.Integer, None),
  271                 ('password_hash', sql.String, 255),
  272                 ('created_at', sql.DateTime, None),
  273                 ('expires_at', sql.DateTime, None),
  274                 ('created_at_int', sql.DateTimeInt, None),
  275                 ('expires_at_int', sql.DateTimeInt, None),
  276                 ('self_service', sql.Boolean, False))
  277         self.assertExpectedSchema('password', cols)
  278 
  279     def test_federated_user_model(self):
  280         cols = (('id', sql.Integer, None),
  281                 ('user_id', sql.String, 64),
  282                 ('idp_id', sql.String, 64),
  283                 ('protocol_id', sql.String, 64),
  284                 ('unique_id', sql.String, 255),
  285                 ('display_name', sql.String, 255))
  286         self.assertExpectedSchema('federated_user', cols)
  287 
  288     def test_nonlocal_user_model(self):
  289         cols = (('domain_id', sql.String, 64),
  290                 ('name', sql.String, 255),
  291                 ('user_id', sql.String, 64))
  292         self.assertExpectedSchema('nonlocal_user', cols)
  293 
  294     def test_group_model(self):
  295         cols = (('id', sql.String, 64),
  296                 ('name', sql.String, 64),
  297                 ('description', sql.Text, None),
  298                 ('domain_id', sql.String, 64),
  299                 ('extra', sql.JsonBlob, None))
  300         self.assertExpectedSchema('group', cols)
  301 
  302     def test_project_model(self):
  303         cols = (('id', sql.String, 64),
  304                 ('name', sql.String, 64),
  305                 ('description', sql.Text, None),
  306                 ('domain_id', sql.String, 64),
  307                 ('enabled', sql.Boolean, None),
  308                 ('extra', sql.JsonBlob, None),
  309                 ('parent_id', sql.String, 64),
  310                 ('is_domain', sql.Boolean, False))
  311         self.assertExpectedSchema('project', cols)
  312 
  313     def test_role_assignment_model(self):
  314         cols = (('type', sql.Enum, None),
  315                 ('actor_id', sql.String, 64),
  316                 ('target_id', sql.String, 64),
  317                 ('role_id', sql.String, 64),
  318                 ('inherited', sql.Boolean, False))
  319         self.assertExpectedSchema('assignment', cols)
  320 
  321     def test_user_group_membership(self):
  322         cols = (('group_id', sql.String, 64),
  323                 ('user_id', sql.String, 64))
  324         self.assertExpectedSchema('user_group_membership', cols)
  325 
  326     def test_revocation_event_model(self):
  327         cols = (('id', sql.Integer, None),
  328                 ('domain_id', sql.String, 64),
  329                 ('project_id', sql.String, 64),
  330                 ('user_id', sql.String, 64),
  331                 ('role_id', sql.String, 64),
  332                 ('trust_id', sql.String, 64),
  333                 ('consumer_id', sql.String, 64),
  334                 ('access_token_id', sql.String, 64),
  335                 ('issued_before', sql.DateTime, None),
  336                 ('expires_at', sql.DateTime, None),
  337                 ('revoked_at', sql.DateTime, None),
  338                 ('audit_id', sql.String, 32),
  339                 ('audit_chain_id', sql.String, 32))
  340         self.assertExpectedSchema('revocation_event', cols)
  341 
  342     def test_project_tags_model(self):
  343         cols = (('project_id', sql.String, 64),
  344                 ('name', sql.Unicode, 255))
  345         self.assertExpectedSchema('project_tag', cols)
  346 
  347 
  348 class SqlIdentity(SqlTests,
  349                   identity_tests.IdentityTests,
  350                   assignment_tests.AssignmentTests,
  351                   assignment_tests.SystemAssignmentTests,
  352                   resource_tests.ResourceTests):
  353     def test_password_hashed(self):
  354         with sql.session_for_read() as session:
  355             user_ref = PROVIDERS.identity_api._get_user(
  356                 session, self.user_foo['id']
  357             )
  358             self.assertNotEqual(self.user_foo['password'],
  359                                 user_ref['password'])
  360 
  361     def test_create_user_with_null_password(self):
  362         user_dict = unit.new_user_ref(
  363             domain_id=CONF.identity.default_domain_id)
  364         user_dict["password"] = None
  365         new_user_dict = PROVIDERS.identity_api.create_user(user_dict)
  366         with sql.session_for_read() as session:
  367             new_user_ref = PROVIDERS.identity_api._get_user(
  368                 session, new_user_dict['id']
  369             )
  370             self.assertIsNone(new_user_ref.password)
  371 
  372     def test_update_user_with_null_password(self):
  373         user_dict = unit.new_user_ref(
  374             domain_id=CONF.identity.default_domain_id)
  375         self.assertTrue(user_dict['password'])
  376         new_user_dict = PROVIDERS.identity_api.create_user(user_dict)
  377         new_user_dict["password"] = None
  378         new_user_dict = PROVIDERS.identity_api.update_user(
  379             new_user_dict['id'], new_user_dict
  380         )
  381         with sql.session_for_read() as session:
  382             new_user_ref = PROVIDERS.identity_api._get_user(
  383                 session, new_user_dict['id']
  384             )
  385             self.assertIsNone(new_user_ref.password)
  386 
  387     def test_delete_user_with_project_association(self):
  388         user = unit.new_user_ref(domain_id=CONF.identity.default_domain_id)
  389         user = PROVIDERS.identity_api.create_user(user)
  390         role_member = unit.new_role_ref()
  391         PROVIDERS.role_api.create_role(role_member['id'], role_member)
  392         PROVIDERS.assignment_api.add_role_to_user_and_project(
  393             user['id'], self.project_bar['id'], role_member['id']
  394         )
  395         PROVIDERS.identity_api.delete_user(user['id'])
  396         self.assertRaises(exception.UserNotFound,
  397                           PROVIDERS.assignment_api.list_projects_for_user,
  398                           user['id'])
  399 
  400     def test_create_user_case_sensitivity(self):
  401         # user name case sensitivity is down to the fact that it is marked as
  402         # an SQL UNIQUE column, which may not be valid for other backends, like
  403         # LDAP.
  404 
  405         # create a ref with a lowercase name
  406         ref = unit.new_user_ref(name=uuid.uuid4().hex.lower(),
  407                                 domain_id=CONF.identity.default_domain_id)
  408         ref = PROVIDERS.identity_api.create_user(ref)
  409 
  410         # assign a new ID with the same name, but this time in uppercase
  411         ref['name'] = ref['name'].upper()
  412         PROVIDERS.identity_api.create_user(ref)
  413 
  414     def test_create_project_case_sensitivity(self):
  415         # project name case sensitivity is down to the fact that it is marked
  416         # as an SQL UNIQUE column, which may not be valid for other backends,
  417         # like LDAP.
  418 
  419         # create a ref with a lowercase name
  420         ref = unit.new_project_ref(domain_id=CONF.identity.default_domain_id)
  421         PROVIDERS.resource_api.create_project(ref['id'], ref)
  422 
  423         # assign a new ID with the same name, but this time in uppercase
  424         ref['id'] = uuid.uuid4().hex
  425         ref['name'] = ref['name'].upper()
  426         PROVIDERS.resource_api.create_project(ref['id'], ref)
  427 
  428     def test_delete_project_with_user_association(self):
  429         user = unit.new_user_ref(domain_id=CONF.identity.default_domain_id)
  430         user = PROVIDERS.identity_api.create_user(user)
  431         role_member = unit.new_role_ref()
  432         PROVIDERS.role_api.create_role(role_member['id'], role_member)
  433         PROVIDERS.assignment_api.add_role_to_user_and_project(
  434             user['id'], self.project_bar['id'], role_member['id']
  435         )
  436         PROVIDERS.resource_api.delete_project(self.project_bar['id'])
  437         projects = PROVIDERS.assignment_api.list_projects_for_user(user['id'])
  438         self.assertEqual([], projects)
  439 
  440     def test_update_project_returns_extra(self):
  441         """Test for backward compatibility with an essex/folsom bug.
  442 
  443         Non-indexed attributes were returned in an 'extra' attribute, instead
  444         of on the entity itself; for consistency and backwards compatibility,
  445         those attributes should be included twice.
  446 
  447         This behavior is specific to the SQL driver.
  448 
  449         """
  450         arbitrary_key = uuid.uuid4().hex
  451         arbitrary_value = uuid.uuid4().hex
  452         project = unit.new_project_ref(
  453             domain_id=CONF.identity.default_domain_id)
  454         project[arbitrary_key] = arbitrary_value
  455         ref = PROVIDERS.resource_api.create_project(project['id'], project)
  456         self.assertEqual(arbitrary_value, ref[arbitrary_key])
  457         self.assertNotIn('extra', ref)
  458 
  459         ref['name'] = uuid.uuid4().hex
  460         ref = PROVIDERS.resource_api.update_project(ref['id'], ref)
  461         self.assertEqual(arbitrary_value, ref[arbitrary_key])
  462         self.assertEqual(arbitrary_value, ref['extra'][arbitrary_key])
  463 
  464     def test_update_user_returns_extra(self):
  465         """Test for backwards-compatibility with an essex/folsom bug.
  466 
  467         Non-indexed attributes were returned in an 'extra' attribute, instead
  468         of on the entity itself; for consistency and backwards compatibility,
  469         those attributes should be included twice.
  470 
  471         This behavior is specific to the SQL driver.
  472 
  473         """
  474         arbitrary_key = uuid.uuid4().hex
  475         arbitrary_value = uuid.uuid4().hex
  476         user = unit.new_user_ref(domain_id=CONF.identity.default_domain_id)
  477         user[arbitrary_key] = arbitrary_value
  478         del user["id"]
  479         ref = PROVIDERS.identity_api.create_user(user)
  480         self.assertEqual(arbitrary_value, ref[arbitrary_key])
  481         self.assertNotIn('password', ref)
  482         self.assertNotIn('extra', ref)
  483 
  484         user['name'] = uuid.uuid4().hex
  485         user['password'] = uuid.uuid4().hex
  486         ref = PROVIDERS.identity_api.update_user(ref['id'], user)
  487         self.assertNotIn('password', ref)
  488         self.assertNotIn('password', ref['extra'])
  489         self.assertEqual(arbitrary_value, ref[arbitrary_key])
  490         self.assertEqual(arbitrary_value, ref['extra'][arbitrary_key])
  491 
  492     def test_sql_user_to_dict_null_default_project_id(self):
  493         user = unit.new_user_ref(domain_id=CONF.identity.default_domain_id)
  494         user = PROVIDERS.identity_api.create_user(user)
  495         with sql.session_for_read() as session:
  496             query = session.query(identity_sql.User)
  497             query = query.filter_by(id=user['id'])
  498             raw_user_ref = query.one()
  499             self.assertIsNone(raw_user_ref.default_project_id)
  500             user_ref = raw_user_ref.to_dict()
  501             self.assertNotIn('default_project_id', user_ref)
  502             session.close()
  503 
  504     def test_list_domains_for_user(self):
  505         domain = unit.new_domain_ref()
  506         PROVIDERS.resource_api.create_domain(domain['id'], domain)
  507         user = unit.new_user_ref(domain_id=domain['id'])
  508 
  509         test_domain1 = unit.new_domain_ref()
  510         PROVIDERS.resource_api.create_domain(test_domain1['id'], test_domain1)
  511         test_domain2 = unit.new_domain_ref()
  512         PROVIDERS.resource_api.create_domain(test_domain2['id'], test_domain2)
  513 
  514         user = PROVIDERS.identity_api.create_user(user)
  515         user_domains = PROVIDERS.assignment_api.list_domains_for_user(
  516             user['id']
  517         )
  518         self.assertEqual(0, len(user_domains))
  519         PROVIDERS.assignment_api.create_grant(
  520             user_id=user['id'], domain_id=test_domain1['id'],
  521             role_id=self.role_member['id']
  522         )
  523         PROVIDERS.assignment_api.create_grant(
  524             user_id=user['id'], domain_id=test_domain2['id'],
  525             role_id=self.role_member['id']
  526         )
  527         user_domains = PROVIDERS.assignment_api.list_domains_for_user(
  528             user['id']
  529         )
  530         self.assertThat(user_domains, matchers.HasLength(2))
  531 
  532     def test_list_domains_for_user_with_grants(self):
  533         # Create two groups each with a role on a different domain, and
  534         # make user1 a member of both groups.  Both these new domains
  535         # should now be included, along with any direct user grants.
  536         domain = unit.new_domain_ref()
  537         PROVIDERS.resource_api.create_domain(domain['id'], domain)
  538         user = unit.new_user_ref(domain_id=domain['id'])
  539         user = PROVIDERS.identity_api.create_user(user)
  540         group1 = unit.new_group_ref(domain_id=domain['id'])
  541         group1 = PROVIDERS.identity_api.create_group(group1)
  542         group2 = unit.new_group_ref(domain_id=domain['id'])
  543         group2 = PROVIDERS.identity_api.create_group(group2)
  544 
  545         test_domain1 = unit.new_domain_ref()
  546         PROVIDERS.resource_api.create_domain(test_domain1['id'], test_domain1)
  547         test_domain2 = unit.new_domain_ref()
  548         PROVIDERS.resource_api.create_domain(test_domain2['id'], test_domain2)
  549         test_domain3 = unit.new_domain_ref()
  550         PROVIDERS.resource_api.create_domain(test_domain3['id'], test_domain3)
  551 
  552         PROVIDERS.identity_api.add_user_to_group(user['id'], group1['id'])
  553         PROVIDERS.identity_api.add_user_to_group(user['id'], group2['id'])
  554 
  555         # Create 3 grants, one user grant, the other two as group grants
  556         PROVIDERS.assignment_api.create_grant(
  557             user_id=user['id'], domain_id=test_domain1['id'],
  558             role_id=self.role_member['id']
  559         )
  560         PROVIDERS.assignment_api.create_grant(
  561             group_id=group1['id'], domain_id=test_domain2['id'],
  562             role_id=self.role_admin['id']
  563         )
  564         PROVIDERS.assignment_api.create_grant(
  565             group_id=group2['id'], domain_id=test_domain3['id'],
  566             role_id=self.role_admin['id']
  567         )
  568         user_domains = PROVIDERS.assignment_api.list_domains_for_user(
  569             user['id']
  570         )
  571         self.assertThat(user_domains, matchers.HasLength(3))
  572 
  573     def test_list_domains_for_user_with_inherited_grants(self):
  574         """Test that inherited roles on the domain are excluded.
  575 
  576         Test Plan:
  577 
  578         - Create two domains, one user, group and role
  579         - Domain1 is given an inherited user role, Domain2 an inherited
  580           group role (for a group of which the user is a member)
  581         - When listing domains for user, neither domain should be returned
  582 
  583         """
  584         domain1 = unit.new_domain_ref()
  585         domain1 = PROVIDERS.resource_api.create_domain(domain1['id'], domain1)
  586         domain2 = unit.new_domain_ref()
  587         domain2 = PROVIDERS.resource_api.create_domain(domain2['id'], domain2)
  588         user = unit.new_user_ref(domain_id=domain1['id'])
  589         user = PROVIDERS.identity_api.create_user(user)
  590         group = unit.new_group_ref(domain_id=domain1['id'])
  591         group = PROVIDERS.identity_api.create_group(group)
  592         PROVIDERS.identity_api.add_user_to_group(user['id'], group['id'])
  593         role = unit.new_role_ref()
  594         PROVIDERS.role_api.create_role(role['id'], role)
  595 
  596         # Create a grant on each domain, one user grant, one group grant,
  597         # both inherited.
  598         PROVIDERS.assignment_api.create_grant(
  599             user_id=user['id'], domain_id=domain1['id'], role_id=role['id'],
  600             inherited_to_projects=True
  601         )
  602         PROVIDERS.assignment_api.create_grant(
  603             group_id=group['id'], domain_id=domain2['id'], role_id=role['id'],
  604             inherited_to_projects=True
  605         )
  606 
  607         user_domains = PROVIDERS.assignment_api.list_domains_for_user(
  608             user['id']
  609         )
  610         # No domains should be returned since both domains have only inherited
  611         # roles assignments.
  612         self.assertThat(user_domains, matchers.HasLength(0))
  613 
  614     def test_list_groups_for_user(self):
  615         domain = self._get_domain_fixture()
  616         test_groups = []
  617         test_users = []
  618         GROUP_COUNT = 3
  619         USER_COUNT = 2
  620 
  621         for x in range(0, USER_COUNT):
  622             new_user = unit.new_user_ref(domain_id=domain['id'])
  623             new_user = PROVIDERS.identity_api.create_user(new_user)
  624             test_users.append(new_user)
  625         positive_user = test_users[0]
  626         negative_user = test_users[1]
  627 
  628         for x in range(0, USER_COUNT):
  629             group_refs = PROVIDERS.identity_api.list_groups_for_user(
  630                 test_users[x]['id'])
  631             self.assertEqual(0, len(group_refs))
  632 
  633         for x in range(0, GROUP_COUNT):
  634             before_count = x
  635             after_count = x + 1
  636             new_group = unit.new_group_ref(domain_id=domain['id'])
  637             new_group = PROVIDERS.identity_api.create_group(new_group)
  638             test_groups.append(new_group)
  639 
  640             # add the user to the group and ensure that the
  641             # group count increases by one for each
  642             group_refs = PROVIDERS.identity_api.list_groups_for_user(
  643                 positive_user['id'])
  644             self.assertEqual(before_count, len(group_refs))
  645             PROVIDERS.identity_api.add_user_to_group(
  646                 positive_user['id'],
  647                 new_group['id'])
  648             group_refs = PROVIDERS.identity_api.list_groups_for_user(
  649                 positive_user['id'])
  650             self.assertEqual(after_count, len(group_refs))
  651 
  652             # Make sure the group count for the unrelated user did not change
  653             group_refs = PROVIDERS.identity_api.list_groups_for_user(
  654                 negative_user['id'])
  655             self.assertEqual(0, len(group_refs))
  656 
  657         # remove the user from each group and ensure that
  658         # the group count reduces by one for each
  659         for x in range(0, 3):
  660             before_count = GROUP_COUNT - x
  661             after_count = GROUP_COUNT - x - 1
  662             group_refs = PROVIDERS.identity_api.list_groups_for_user(
  663                 positive_user['id'])
  664             self.assertEqual(before_count, len(group_refs))
  665             PROVIDERS.identity_api.remove_user_from_group(
  666                 positive_user['id'],
  667                 test_groups[x]['id'])
  668             group_refs = PROVIDERS.identity_api.list_groups_for_user(
  669                 positive_user['id'])
  670             self.assertEqual(after_count, len(group_refs))
  671             # Make sure the group count for the unrelated user
  672             # did not change
  673             group_refs = PROVIDERS.identity_api.list_groups_for_user(
  674                 negative_user['id'])
  675             self.assertEqual(0, len(group_refs))
  676 
  677     def test_storing_null_domain_id_in_project_ref(self):
  678         """Test the special storage of domain_id=None in sql resource driver.
  679 
  680         The resource driver uses a special value in place of None for domain_id
  681         in the project record. This shouldn't escape the driver. Hence we test
  682         the interface to ensure that you can store a domain_id of None, and
  683         that any special value used inside the driver does not escape through
  684         the interface.
  685 
  686         """
  687         spoiler_project = unit.new_project_ref(
  688             domain_id=CONF.identity.default_domain_id)
  689         PROVIDERS.resource_api.create_project(
  690             spoiler_project['id'], spoiler_project
  691         )
  692 
  693         # First let's create a project with a None domain_id and make sure we
  694         # can read it back.
  695         project = unit.new_project_ref(domain_id=None, is_domain=True)
  696         project = PROVIDERS.resource_api.create_project(project['id'], project)
  697         ref = PROVIDERS.resource_api.get_project(project['id'])
  698         self.assertDictEqual(project, ref)
  699 
  700         # Can we get it by name?
  701         ref = PROVIDERS.resource_api.get_project_by_name(project['name'], None)
  702         self.assertDictEqual(project, ref)
  703 
  704         # Can we filter for them - create a second domain to ensure we are
  705         # testing the receipt of more than one.
  706         project2 = unit.new_project_ref(domain_id=None, is_domain=True)
  707         project2 = PROVIDERS.resource_api.create_project(
  708             project2['id'], project2
  709         )
  710         hints = driver_hints.Hints()
  711         hints.add_filter('domain_id', None)
  712         refs = PROVIDERS.resource_api.list_projects(hints)
  713         self.assertThat(refs, matchers.HasLength(2 + self.domain_count))
  714         self.assertIn(project, refs)
  715         self.assertIn(project2, refs)
  716 
  717         # Can we update it?
  718         project['name'] = uuid.uuid4().hex
  719         PROVIDERS.resource_api.update_project(project['id'], project)
  720         ref = PROVIDERS.resource_api.get_project(project['id'])
  721         self.assertDictEqual(project, ref)
  722 
  723         # Finally, make sure we can delete it
  724         project['enabled'] = False
  725         PROVIDERS.resource_api.update_project(project['id'], project)
  726         PROVIDERS.resource_api.delete_project(project['id'])
  727         self.assertRaises(exception.ProjectNotFound,
  728                           PROVIDERS.resource_api.get_project,
  729                           project['id'])
  730 
  731     def test_hidden_project_domain_root_is_really_hidden(self):
  732         """Ensure we cannot access the hidden root of all project domains.
  733 
  734         Calling any of the driver methods should result in the same as
  735         would be returned if we passed a project that does not exist. We don't
  736         test create_project, since we do not allow a caller of our API to
  737         specify their own ID for a new entity.
  738 
  739         """
  740         def _exercise_project_api(ref_id):
  741             driver = PROVIDERS.resource_api.driver
  742             self.assertRaises(exception.ProjectNotFound,
  743                               driver.get_project,
  744                               ref_id)
  745 
  746             self.assertRaises(exception.ProjectNotFound,
  747                               driver.get_project_by_name,
  748                               resource.NULL_DOMAIN_ID,
  749                               ref_id)
  750 
  751             project_ids = [x['id'] for x in
  752                            driver.list_projects(driver_hints.Hints())]
  753             self.assertNotIn(ref_id, project_ids)
  754 
  755             projects = driver.list_projects_from_ids([ref_id])
  756             self.assertThat(projects, matchers.HasLength(0))
  757 
  758             project_ids = [x for x in
  759                            driver.list_project_ids_from_domain_ids([ref_id])]
  760             self.assertNotIn(ref_id, project_ids)
  761 
  762             self.assertRaises(exception.DomainNotFound,
  763                               driver.list_projects_in_domain,
  764                               ref_id)
  765 
  766             project_ids = [
  767                 x['id'] for x in
  768                 driver.list_projects_acting_as_domain(driver_hints.Hints())]
  769             self.assertNotIn(ref_id, project_ids)
  770 
  771             projects = driver.list_projects_in_subtree(ref_id)
  772             self.assertThat(projects, matchers.HasLength(0))
  773 
  774             self.assertRaises(exception.ProjectNotFound,
  775                               driver.list_project_parents,
  776                               ref_id)
  777 
  778             # A non-existing project just returns True from the driver
  779             self.assertTrue(driver.is_leaf_project(ref_id))
  780 
  781             self.assertRaises(exception.ProjectNotFound,
  782                               driver.update_project,
  783                               ref_id,
  784                               {})
  785 
  786             self.assertRaises(exception.ProjectNotFound,
  787                               driver.delete_project,
  788                               ref_id)
  789 
  790             # Deleting list of projects that includes a non-existing project
  791             # should be silent. The root domain <<keystone.domain.root>> can't
  792             # be deleted.
  793             if ref_id != resource.NULL_DOMAIN_ID:
  794                 driver.delete_projects_from_ids([ref_id])
  795 
  796         _exercise_project_api(uuid.uuid4().hex)
  797         _exercise_project_api(resource.NULL_DOMAIN_ID)
  798 
  799     def test_list_users_call_count(self):
  800         """There should not be O(N) queries."""
  801         # create 10 users. 10 is just a random number
  802         for i in range(10):
  803             user = unit.new_user_ref(domain_id=CONF.identity.default_domain_id)
  804             PROVIDERS.identity_api.create_user(user)
  805 
  806         # sqlalchemy emits various events and allows to listen to them. Here
  807         # bound method `query_counter` will be called each time when a query
  808         # is compiled
  809         class CallCounter(object):
  810             def __init__(self):
  811                 self.calls = 0
  812 
  813             def reset(self):
  814                 self.calls = 0
  815 
  816             def query_counter(self, query):
  817                 self.calls += 1
  818 
  819         counter = CallCounter()
  820         sqlalchemy.event.listen(sqlalchemy.orm.query.Query, 'before_compile',
  821                                 counter.query_counter)
  822 
  823         first_call_users = PROVIDERS.identity_api.list_users()
  824         first_call_counter = counter.calls
  825         # add 10 more users
  826         for i in range(10):
  827             user = unit.new_user_ref(domain_id=CONF.identity.default_domain_id)
  828             PROVIDERS.identity_api.create_user(user)
  829         counter.reset()
  830         second_call_users = PROVIDERS.identity_api.list_users()
  831         # ensure that the number of calls does not depend on the number of
  832         # users fetched.
  833         self.assertNotEqual(len(first_call_users), len(second_call_users))
  834         self.assertEqual(first_call_counter, counter.calls)
  835         self.assertEqual(3, counter.calls)
  836 
  837     def test_check_project_depth(self):
  838         # Create a 3 level project tree:
  839         #
  840         # default_domain
  841         #       |
  842         #   project_1
  843         #       |
  844         #   project_2
  845         project_1 = unit.new_project_ref(
  846             domain_id=CONF.identity.default_domain_id)
  847         PROVIDERS.resource_api.create_project(project_1['id'], project_1)
  848         project_2 = unit.new_project_ref(
  849             domain_id=CONF.identity.default_domain_id,
  850             parent_id=project_1['id'])
  851         PROVIDERS.resource_api.create_project(project_2['id'], project_2)
  852 
  853         # if max_depth is None or >= current project depth, return nothing.
  854         resp = PROVIDERS.resource_api.check_project_depth(max_depth=None)
  855         self.assertIsNone(resp)
  856         resp = PROVIDERS.resource_api.check_project_depth(max_depth=3)
  857         self.assertIsNone(resp)
  858         resp = PROVIDERS.resource_api.check_project_depth(max_depth=4)
  859         self.assertIsNone(resp)
  860         # if max_depth < current project depth, raise LimitTreeExceedError
  861         self.assertRaises(exception.LimitTreeExceedError,
  862                           PROVIDERS.resource_api.check_project_depth,
  863                           2)
  864 
  865     def test_update_user_with_stale_data_forces_retry(self):
  866         # Capture log output so we know oslo.db attempted a retry
  867         log_fixture = self.useFixture(fixtures.FakeLogger(level=log.DEBUG))
  868 
  869         # Create a new user
  870         user_dict = unit.new_user_ref(
  871             domain_id=CONF.identity.default_domain_id)
  872         new_user_dict = PROVIDERS.identity_api.create_user(user_dict)
  873 
  874         side_effects = [
  875             # Raise a StaleDataError simulating that another client has
  876             # updated the user's password while this client's request was
  877             # being processed
  878             sqlalchemy.orm.exc.StaleDataError,
  879             # The oslo.db library will retry the request, so the second
  880             # time this method is called let's return a valid session
  881             # object
  882             sql.session_for_write()
  883         ]
  884         with mock.patch('keystone.common.sql.session_for_write') as m:
  885             m.side_effect = side_effects
  886 
  887             # Update a user's attribute, the first attempt will fail but
  888             # oslo.db will handle the exception and retry, the second attempt
  889             # will succeed
  890             new_user_dict['email'] = uuid.uuid4().hex
  891             PROVIDERS.identity_api.update_user(
  892                 new_user_dict['id'], new_user_dict)
  893 
  894         # Make sure oslo.db retried the update by checking the log output
  895         expected_log_message = (
  896             'Performing DB retry for function keystone.identity.backends.sql'
  897         )
  898         self.assertIn(expected_log_message, log_fixture.output)
  899 
  900 
  901 class SqlTrust(SqlTests, trust_tests.TrustTests):
  902 
  903     def test_trust_expires_at_int_matches_expires_at(self):
  904         with sql.session_for_write() as session:
  905             new_id = uuid.uuid4().hex
  906             self.create_sample_trust(new_id)
  907             trust_ref = session.query(trust_sql.TrustModel).get(new_id)
  908             self.assertIsNotNone(trust_ref._expires_at)
  909             self.assertEqual(trust_ref._expires_at, trust_ref.expires_at_int)
  910             self.assertEqual(trust_ref.expires_at, trust_ref.expires_at_int)
  911 
  912 
  913 class SqlCatalog(SqlTests, catalog_tests.CatalogTests):
  914 
  915     _legacy_endpoint_id_in_endpoint = True
  916     _enabled_default_to_true_when_creating_endpoint = True
  917 
  918     def test_get_v3_catalog_project_non_exist(self):
  919         service = unit.new_service_ref()
  920         PROVIDERS.catalog_api.create_service(service['id'], service)
  921 
  922         malformed_url = "http://192.168.1.104:8774/v2/$(project)s"
  923         endpoint = unit.new_endpoint_ref(service_id=service['id'],
  924                                          url=malformed_url,
  925                                          region_id=None)
  926         PROVIDERS.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
  927         self.assertRaises(exception.ProjectNotFound,
  928                           PROVIDERS.catalog_api.get_v3_catalog,
  929                           'fake-user',
  930                           'fake-project')
  931 
  932     def test_get_v3_catalog_with_empty_public_url(self):
  933         service = unit.new_service_ref()
  934         PROVIDERS.catalog_api.create_service(service['id'], service)
  935 
  936         endpoint = unit.new_endpoint_ref(url='', service_id=service['id'],
  937                                          region_id=None)
  938         PROVIDERS.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
  939 
  940         catalog = PROVIDERS.catalog_api.get_v3_catalog(self.user_foo['id'],
  941                                                        self.project_bar['id'])
  942         catalog_endpoint = catalog[0]
  943         self.assertEqual(service['name'], catalog_endpoint['name'])
  944         self.assertEqual(service['id'], catalog_endpoint['id'])
  945         self.assertEqual([], catalog_endpoint['endpoints'])
  946 
  947     def test_create_endpoint_region_returns_not_found(self):
  948         service = unit.new_service_ref()
  949         PROVIDERS.catalog_api.create_service(service['id'], service)
  950 
  951         endpoint = unit.new_endpoint_ref(region_id=uuid.uuid4().hex,
  952                                          service_id=service['id'])
  953 
  954         self.assertRaises(exception.ValidationError,
  955                           PROVIDERS.catalog_api.create_endpoint,
  956                           endpoint['id'],
  957                           endpoint.copy())
  958 
  959     def test_create_region_invalid_id(self):
  960         region = unit.new_region_ref(id='0' * 256)
  961 
  962         self.assertRaises(exception.StringLengthExceeded,
  963                           PROVIDERS.catalog_api.create_region,
  964                           region)
  965 
  966     def test_create_region_invalid_parent_id(self):
  967         region = unit.new_region_ref(parent_region_id='0' * 256)
  968 
  969         self.assertRaises(exception.RegionNotFound,
  970                           PROVIDERS.catalog_api.create_region,
  971                           region)
  972 
  973     def test_delete_region_with_endpoint(self):
  974         # create a region
  975         region = unit.new_region_ref()
  976         PROVIDERS.catalog_api.create_region(region)
  977 
  978         # create a child region
  979         child_region = unit.new_region_ref(parent_region_id=region['id'])
  980         PROVIDERS.catalog_api.create_region(child_region)
  981         # create a service
  982         service = unit.new_service_ref()
  983         PROVIDERS.catalog_api.create_service(service['id'], service)
  984 
  985         # create an endpoint attached to the service and child region
  986         child_endpoint = unit.new_endpoint_ref(region_id=child_region['id'],
  987                                                service_id=service['id'])
  988 
  989         PROVIDERS.catalog_api.create_endpoint(
  990             child_endpoint['id'], child_endpoint
  991         )
  992         self.assertRaises(exception.RegionDeletionError,
  993                           PROVIDERS.catalog_api.delete_region,
  994                           child_region['id'])
  995 
  996         # create an endpoint attached to the service and parent region
  997         endpoint = unit.new_endpoint_ref(region_id=region['id'],
  998                                          service_id=service['id'])
  999 
 1000         PROVIDERS.catalog_api.create_endpoint(endpoint['id'], endpoint)
 1001         self.assertRaises(exception.RegionDeletionError,
 1002                           PROVIDERS.catalog_api.delete_region,
 1003                           region['id'])
 1004 
 1005     def test_v3_catalog_domain_scoped_token(self):
 1006         # test the case that project_id is None.
 1007         srv_1 = unit.new_service_ref()
 1008         PROVIDERS.catalog_api.create_service(srv_1['id'], srv_1)
 1009         endpoint_1 = unit.new_endpoint_ref(service_id=srv_1['id'],
 1010                                            region_id=None)
 1011         PROVIDERS.catalog_api.create_endpoint(endpoint_1['id'], endpoint_1)
 1012 
 1013         srv_2 = unit.new_service_ref()
 1014         PROVIDERS.catalog_api.create_service(srv_2['id'], srv_2)
 1015         endpoint_2 = unit.new_endpoint_ref(service_id=srv_2['id'],
 1016                                            region_id=None)
 1017         PROVIDERS.catalog_api.create_endpoint(endpoint_2['id'], endpoint_2)
 1018 
 1019         self.config_fixture.config(group='endpoint_filter',
 1020                                    return_all_endpoints_if_no_filter=True)
 1021         catalog_ref = PROVIDERS.catalog_api.get_v3_catalog(
 1022             uuid.uuid4().hex, None
 1023         )
 1024         self.assertThat(catalog_ref, matchers.HasLength(2))
 1025         self.config_fixture.config(group='endpoint_filter',
 1026                                    return_all_endpoints_if_no_filter=False)
 1027         catalog_ref = PROVIDERS.catalog_api.get_v3_catalog(
 1028             uuid.uuid4().hex, None
 1029         )
 1030         self.assertThat(catalog_ref, matchers.HasLength(0))
 1031 
 1032     def test_v3_catalog_endpoint_filter_enabled(self):
 1033         srv_1 = unit.new_service_ref()
 1034         PROVIDERS.catalog_api.create_service(srv_1['id'], srv_1)
 1035         endpoint_1 = unit.new_endpoint_ref(service_id=srv_1['id'],
 1036                                            region_id=None)
 1037         PROVIDERS.catalog_api.create_endpoint(endpoint_1['id'], endpoint_1)
 1038         endpoint_2 = unit.new_endpoint_ref(service_id=srv_1['id'],
 1039                                            region_id=None)
 1040         PROVIDERS.catalog_api.create_endpoint(endpoint_2['id'], endpoint_2)
 1041         # create endpoint-project association.
 1042         PROVIDERS.catalog_api.add_endpoint_to_project(
 1043             endpoint_1['id'],
 1044             self.project_bar['id'])
 1045 
 1046         catalog_ref = PROVIDERS.catalog_api.get_v3_catalog(
 1047             uuid.uuid4().hex, self.project_bar['id']
 1048         )
 1049         self.assertThat(catalog_ref, matchers.HasLength(1))
 1050         self.assertThat(catalog_ref[0]['endpoints'], matchers.HasLength(1))
 1051         # the endpoint is that defined in the endpoint-project association.
 1052         self.assertEqual(endpoint_1['id'],
 1053                          catalog_ref[0]['endpoints'][0]['id'])
 1054 
 1055     def test_v3_catalog_endpoint_filter_disabled(self):
 1056         # there is no endpoint-project association defined.
 1057         self.config_fixture.config(group='endpoint_filter',
 1058                                    return_all_endpoints_if_no_filter=True)
 1059         srv_1 = unit.new_service_ref()
 1060         PROVIDERS.catalog_api.create_service(srv_1['id'], srv_1)
 1061         endpoint_1 = unit.new_endpoint_ref(service_id=srv_1['id'],
 1062                                            region_id=None)
 1063         PROVIDERS.catalog_api.create_endpoint(endpoint_1['id'], endpoint_1)
 1064 
 1065         srv_2 = unit.new_service_ref()
 1066         PROVIDERS.catalog_api.create_service(srv_2['id'], srv_2)
 1067 
 1068         catalog_ref = PROVIDERS.catalog_api.get_v3_catalog(
 1069             uuid.uuid4().hex, self.project_bar['id']
 1070         )
 1071         self.assertThat(catalog_ref, matchers.HasLength(2))
 1072         srv_id_list = [catalog_ref[0]['id'], catalog_ref[1]['id']]
 1073         self.assertItemsEqual([srv_1['id'], srv_2['id']], srv_id_list)
 1074 
 1075 
 1076 class SqlPolicy(SqlTests, policy_tests.PolicyTests):
 1077     pass
 1078 
 1079 
 1080 class SqlInheritance(SqlTests, assignment_tests.InheritanceTests):
 1081     pass
 1082 
 1083 
 1084 class SqlImpliedRoles(SqlTests, assignment_tests.ImpliedRoleTests):
 1085     pass
 1086 
 1087 
 1088 class SqlFilterTests(SqlTests, identity_tests.FilterTests):
 1089 
 1090     def clean_up_entities(self):
 1091         """Clean up entity test data from Filter Test Cases."""
 1092         for entity in ['user', 'group', 'project']:
 1093             self._delete_test_data(entity, self.entity_list[entity])
 1094             self._delete_test_data(entity, self.domain1_entity_list[entity])
 1095         del self.entity_list
 1096         del self.domain1_entity_list
 1097         self.domain1['enabled'] = False
 1098         PROVIDERS.resource_api.update_domain(self.domain1['id'], self.domain1)
 1099         PROVIDERS.resource_api.delete_domain(self.domain1['id'])
 1100         del self.domain1
 1101 
 1102     def test_list_entities_filtered_by_domain(self):
 1103         # NOTE(henry-nash): This method is here rather than in
 1104         # unit.identity.test_backends since any domain filtering with LDAP is
 1105         # handled by the manager layer (and is already tested elsewhere) not at
 1106         # the driver level.
 1107         self.addCleanup(self.clean_up_entities)
 1108         self.domain1 = unit.new_domain_ref()
 1109         PROVIDERS.resource_api.create_domain(self.domain1['id'], self.domain1)
 1110 
 1111         self.entity_list = {}
 1112         self.domain1_entity_list = {}
 1113         for entity in ['user', 'group', 'project']:
 1114             # Create 5 entities, 3 of which are in domain1
 1115             DOMAIN1_ENTITIES = 3
 1116             self.entity_list[entity] = self._create_test_data(entity, 2)
 1117             self.domain1_entity_list[entity] = self._create_test_data(
 1118                 entity, DOMAIN1_ENTITIES, self.domain1['id'])
 1119 
 1120             # Should get back the DOMAIN1_ENTITIES in domain1
 1121             hints = driver_hints.Hints()
 1122             hints.add_filter('domain_id', self.domain1['id'])
 1123             entities = self._list_entities(entity)(hints=hints)
 1124             self.assertEqual(DOMAIN1_ENTITIES, len(entities))
 1125             self._match_with_list(entities, self.domain1_entity_list[entity])
 1126             # Check the driver has removed the filter from the list hints
 1127             self.assertFalse(hints.get_exact_filter_by_name('domain_id'))
 1128 
 1129     def test_filter_sql_injection_attack(self):
 1130         """Test against sql injection attack on filters.
 1131 
 1132         Test Plan:
 1133         - Attempt to get all entities back by passing a two-term attribute
 1134         - Attempt to piggyback filter to damage DB (e.g. drop table)
 1135 
 1136         """
 1137         # Check we have some users
 1138         users = PROVIDERS.identity_api.list_users()
 1139         self.assertGreater(len(users), 0)
 1140 
 1141         hints = driver_hints.Hints()
 1142         hints.add_filter('name', "anything' or 'x'='x")
 1143         users = PROVIDERS.identity_api.list_users(hints=hints)
 1144         self.assertEqual(0, len(users))
 1145 
 1146         # See if we can add a SQL command...use the group table instead of the
 1147         # user table since 'user' is reserved word for SQLAlchemy.
 1148         group = unit.new_group_ref(domain_id=CONF.identity.default_domain_id)
 1149         group = PROVIDERS.identity_api.create_group(group)
 1150 
 1151         hints = driver_hints.Hints()
 1152         hints.add_filter('name', "x'; drop table group")
 1153         groups = PROVIDERS.identity_api.list_groups(hints=hints)
 1154         self.assertEqual(0, len(groups))
 1155 
 1156         groups = PROVIDERS.identity_api.list_groups()
 1157         self.assertGreater(len(groups), 0)
 1158 
 1159 
 1160 class SqlLimitTests(SqlTests, identity_tests.LimitTests):
 1161     def setUp(self):
 1162         super(SqlLimitTests, self).setUp()
 1163         identity_tests.LimitTests.setUp(self)
 1164 
 1165 
 1166 class FakeTable(sql.ModelBase):
 1167     __tablename__ = 'test_table'
 1168     col = sql.Column(sql.String(32), primary_key=True)
 1169 
 1170     @sql.handle_conflicts('keystone')
 1171     def insert(self):
 1172         raise db_exception.DBDuplicateEntry
 1173 
 1174     @sql.handle_conflicts('keystone')
 1175     def update(self):
 1176         raise db_exception.DBError(
 1177             inner_exception=exc.IntegrityError('a', 'a', 'a'))
 1178 
 1179     @sql.handle_conflicts('keystone')
 1180     def lookup(self):
 1181         raise KeyError
 1182 
 1183 
 1184 class SqlDecorators(unit.TestCase):
 1185 
 1186     def test_initialization_fail(self):
 1187         self.assertRaises(exception.StringLengthExceeded,
 1188                           FakeTable, col='a' * 64)
 1189 
 1190     def test_initialization(self):
 1191         tt = FakeTable(col='a')
 1192         self.assertEqual('a', tt.col)
 1193 
 1194     def test_conflict_happend(self):
 1195         self.assertRaises(exception.Conflict, FakeTable().insert)
 1196         self.assertRaises(exception.UnexpectedError, FakeTable().update)
 1197 
 1198     def test_not_conflict_error(self):
 1199         self.assertRaises(KeyError, FakeTable().lookup)
 1200 
 1201 
 1202 class SqlModuleInitialization(unit.TestCase):
 1203 
 1204     @mock.patch.object(sql.core, 'CONF')
 1205     @mock.patch.object(options, 'set_defaults')
 1206     def test_initialize_module(self, set_defaults, CONF):
 1207         sql.initialize()
 1208         set_defaults.assert_called_with(CONF,
 1209                                         connection='sqlite:///keystone.db')
 1210 
 1211 
 1212 class SqlCredential(SqlTests):
 1213 
 1214     def _create_credential_with_user_id(self, user_id=uuid.uuid4().hex):
 1215         credential = unit.new_credential_ref(user_id=user_id,
 1216                                              extra=uuid.uuid4().hex,
 1217                                              type=uuid.uuid4().hex)
 1218         PROVIDERS.credential_api.create_credential(
 1219             credential['id'], credential
 1220         )
 1221         return credential
 1222 
 1223     def _validateCredentialList(self, retrieved_credentials,
 1224                                 expected_credentials):
 1225         self.assertEqual(len(expected_credentials), len(retrieved_credentials))
 1226         retrived_ids = [c['id'] for c in retrieved_credentials]
 1227         for cred in expected_credentials:
 1228             self.assertIn(cred['id'], retrived_ids)
 1229 
 1230     def setUp(self):
 1231         self.useFixture(database.Database())
 1232         super(SqlCredential, self).setUp()
 1233         self.useFixture(
 1234             ksfixtures.KeyRepository(
 1235                 self.config_fixture,
 1236                 'credential',
 1237                 credential_provider.MAX_ACTIVE_KEYS
 1238             )
 1239         )
 1240 
 1241         self.credentials = []
 1242         for _ in range(3):
 1243             self.credentials.append(
 1244                 self._create_credential_with_user_id())
 1245         self.user_credentials = []
 1246         for _ in range(3):
 1247             cred = self._create_credential_with_user_id(self.user_foo['id'])
 1248             self.user_credentials.append(cred)
 1249             self.credentials.append(cred)
 1250 
 1251     def test_list_credentials(self):
 1252         credentials = PROVIDERS.credential_api.list_credentials()
 1253         self._validateCredentialList(credentials, self.credentials)
 1254         # test filtering using hints
 1255         hints = driver_hints.Hints()
 1256         hints.add_filter('user_id', self.user_foo['id'])
 1257         credentials = PROVIDERS.credential_api.list_credentials(hints)
 1258         self._validateCredentialList(credentials, self.user_credentials)
 1259 
 1260     def test_list_credentials_for_user(self):
 1261         credentials = PROVIDERS.credential_api.list_credentials_for_user(
 1262             self.user_foo['id'])
 1263         self._validateCredentialList(credentials, self.user_credentials)
 1264 
 1265     def test_list_credentials_for_user_and_type(self):
 1266         cred = self.user_credentials[0]
 1267         credentials = PROVIDERS.credential_api.list_credentials_for_user(
 1268             self.user_foo['id'], type=cred['type'])
 1269         self._validateCredentialList(credentials, [cred])
 1270 
 1271     def test_create_credential_is_encrypted_when_stored(self):
 1272         credential = unit.new_credential_ref(user_id=uuid.uuid4().hex)
 1273         credential_id = credential['id']
 1274         returned_credential = PROVIDERS.credential_api.create_credential(
 1275             credential_id,
 1276             credential
 1277         )
 1278 
 1279         # Make sure the `blob` is *not* encrypted when returned from the
 1280         # credential API.
 1281         self.assertEqual(returned_credential['blob'], credential['blob'])
 1282 
 1283         credential_from_backend = (
 1284             PROVIDERS.credential_api.driver.get_credential(credential_id)
 1285         )
 1286 
 1287         # Pull the credential directly from the backend, the `blob` should be
 1288         # encrypted.
 1289         self.assertNotEqual(
 1290             credential_from_backend['encrypted_blob'],
 1291             credential['blob']
 1292         )
 1293 
 1294     def test_list_credentials_is_decrypted(self):
 1295         credential = unit.new_credential_ref(user_id=uuid.uuid4().hex)
 1296         credential_id = credential['id']
 1297 
 1298         created_credential = PROVIDERS.credential_api.create_credential(
 1299             credential_id,
 1300             credential
 1301         )
 1302 
 1303         # Pull the credential directly from the backend, the `blob` should be
 1304         # encrypted.
 1305         credential_from_backend = (
 1306             PROVIDERS.credential_api.driver.get_credential(credential_id)
 1307         )
 1308         self.assertNotEqual(
 1309             credential_from_backend['encrypted_blob'],
 1310             credential['blob']
 1311         )
 1312 
 1313         # Make sure the `blob` values listed from the API are not encrypted.
 1314         listed_credentials = PROVIDERS.credential_api.list_credentials()
 1315         self.assertIn(created_credential, listed_credentials)
 1316 
 1317 
 1318 class SqlRegisteredLimit(SqlTests, limit_tests.RegisteredLimitTests):
 1319 
 1320     def setUp(self):
 1321         super(SqlRegisteredLimit, self).setUp()
 1322 
 1323         fixtures_to_cleanup = []
 1324         for service in default_fixtures.SERVICES:
 1325             service_id = service['id']
 1326             rv = PROVIDERS.catalog_api.create_service(service_id, service)
 1327             attrname = service['extra']['name']
 1328             setattr(self, attrname, rv)
 1329             fixtures_to_cleanup.append(attrname)
 1330         for region in default_fixtures.REGIONS:
 1331             rv = PROVIDERS.catalog_api.create_region(region)
 1332             attrname = region['id']
 1333             setattr(self, attrname, rv)
 1334             fixtures_to_cleanup.append(attrname)
 1335         self.addCleanup(self.cleanup_instance(*fixtures_to_cleanup))
 1336 
 1337 
 1338 class SqlLimit(SqlTests, limit_tests.LimitTests):
 1339 
 1340     def setUp(self):
 1341         super(SqlLimit, self).setUp()
 1342 
 1343         fixtures_to_cleanup = []
 1344         for service in default_fixtures.SERVICES:
 1345             service_id = service['id']
 1346             rv = PROVIDERS.catalog_api.create_service(service_id, service)
 1347             attrname = service['extra']['name']
 1348             setattr(self, attrname, rv)
 1349             fixtures_to_cleanup.append(attrname)
 1350         for region in default_fixtures.REGIONS:
 1351             rv = PROVIDERS.catalog_api.create_region(region)
 1352             attrname = region['id']
 1353             setattr(self, attrname, rv)
 1354             fixtures_to_cleanup.append(attrname)
 1355         self.addCleanup(self.cleanup_instance(*fixtures_to_cleanup))
 1356 
 1357         registered_limit_1 = unit.new_registered_limit_ref(
 1358             service_id=self.service_one['id'],
 1359             region_id=self.region_one['id'],
 1360             resource_name='volume', default_limit=10, id=uuid.uuid4().hex)
 1361         registered_limit_2 = unit.new_registered_limit_ref(
 1362             service_id=self.service_one['id'],
 1363             region_id=self.region_two['id'],
 1364             resource_name='snapshot', default_limit=10, id=uuid.uuid4().hex)
 1365         registered_limit_3 = unit.new_registered_limit_ref(
 1366             service_id=self.service_one['id'],
 1367             region_id=self.region_two['id'],
 1368             resource_name='backup', default_limit=10, id=uuid.uuid4().hex)
 1369         PROVIDERS.unified_limit_api.create_registered_limits(
 1370             [registered_limit_1, registered_limit_2, registered_limit_3])