"Fossies" - the Fresh Open Source Software Archive

Member "keystone-17.0.0/keystone/federation/backends/sql.py" (13 May 2020, 15601 Bytes) of package /linux/misc/openstack/keystone-17.0.0.tar.gz:


As a special service, "Fossies" has tried to format the requested source page into HTML format using (guessed) Python source code syntax highlighting (style: standard) with prefixed line numbers. Alternatively, you can view or download the uninterpreted source code file here. For more information about "sql.py", see the Fossies "Dox" file reference documentation and the latest Fossies "Diffs" side-by-side code changes report: 16.0.1_vs_17.0.0.

    1 # Copyright 2014 OpenStack Foundation
    2 #
    3 # Licensed under the Apache License, Version 2.0 (the "License"); you may
    4 # not use this file except in compliance with the License. You may obtain
    5 # a copy of the License at
    6 #
    7 #      http://www.apache.org/licenses/LICENSE-2.0
    8 #
    9 # Unless required by applicable law or agreed to in writing, software
   10 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
   11 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
   12 # License for the specific language governing permissions and limitations
   13 # under the License.
   14 
   15 from oslo_log import log
   16 from oslo_serialization import jsonutils
   17 from sqlalchemy import orm
   18 
   19 from keystone.common import sql
   20 from keystone import exception
   21 from keystone.federation.backends import base
   22 from keystone.i18n import _
   23 
   24 
# Module-level logger, named after this module through ``__name__``.
LOG = log.getLogger(__name__)
   26 
   27 
   28 class FederationProtocolModel(sql.ModelBase, sql.ModelDictMixin):
   29     __tablename__ = 'federation_protocol'
   30     attributes = ['id', 'idp_id', 'mapping_id', 'remote_id_attribute']
   31     mutable_attributes = frozenset(['mapping_id', 'remote_id_attribute'])
   32 
   33     id = sql.Column(sql.String(64), primary_key=True)
   34     idp_id = sql.Column(sql.String(64), sql.ForeignKey('identity_provider.id',
   35                         ondelete='CASCADE'), primary_key=True)
   36     mapping_id = sql.Column(sql.String(64), nullable=False)
   37     remote_id_attribute = sql.Column(sql.String(64))
   38 
   39     @classmethod
   40     def from_dict(cls, dictionary):
   41         new_dictionary = dictionary.copy()
   42         return cls(**new_dictionary)
   43 
   44     def to_dict(self):
   45         """Return a dictionary with model's attributes."""
   46         d = dict()
   47         for attr in self.__class__.attributes:
   48             d[attr] = getattr(self, attr)
   49         return d
   50 
   51 
   52 class IdentityProviderModel(sql.ModelBase, sql.ModelDictMixin):
   53     __tablename__ = 'identity_provider'
   54     attributes = ['id', 'domain_id', 'enabled', 'description', 'remote_ids',
   55                   'authorization_ttl']
   56     mutable_attributes = frozenset(['description', 'enabled', 'remote_ids',
   57                                     'authorization_ttl'])
   58 
   59     id = sql.Column(sql.String(64), primary_key=True)
   60     domain_id = sql.Column(sql.String(64), nullable=False)
   61     enabled = sql.Column(sql.Boolean, nullable=False)
   62     description = sql.Column(sql.Text(), nullable=True)
   63     authorization_ttl = sql.Column(sql.Integer, nullable=True)
   64 
   65     remote_ids = orm.relationship('IdPRemoteIdsModel',
   66                                   order_by='IdPRemoteIdsModel.remote_id',
   67                                   cascade='all, delete-orphan')
   68     expiring_user_group_memberships = orm.relationship(
   69         'ExpiringUserGroupMembership',
   70         cascade='all, delete-orphan',
   71         backref="idp"
   72     )
   73 
   74     @classmethod
   75     def from_dict(cls, dictionary):
   76         new_dictionary = dictionary.copy()
   77         remote_ids_list = new_dictionary.pop('remote_ids', None)
   78         if not remote_ids_list:
   79             remote_ids_list = []
   80         identity_provider = cls(**new_dictionary)
   81         remote_ids = []
   82         # NOTE(fmarco76): the remote_ids_list contains only remote ids
   83         # associated with the IdP because of the "relationship" established in
   84         # sqlalchemy and corresponding to the FK in the idp_remote_ids table
   85         for remote in remote_ids_list:
   86             remote_ids.append(IdPRemoteIdsModel(remote_id=remote))
   87         identity_provider.remote_ids = remote_ids
   88         return identity_provider
   89 
   90     def to_dict(self):
   91         """Return a dictionary with model's attributes."""
   92         d = dict()
   93         for attr in self.__class__.attributes:
   94             d[attr] = getattr(self, attr)
   95         d['remote_ids'] = []
   96         for remote in self.remote_ids:
   97             d['remote_ids'].append(remote.remote_id)
   98         return d
   99 
  100 
  101 class IdPRemoteIdsModel(sql.ModelBase, sql.ModelDictMixin):
  102     __tablename__ = 'idp_remote_ids'
  103     attributes = ['idp_id', 'remote_id']
  104     mutable_attributes = frozenset(['idp_id', 'remote_id'])
  105 
  106     idp_id = sql.Column(sql.String(64),
  107                         sql.ForeignKey('identity_provider.id',
  108                                        ondelete='CASCADE'))
  109     remote_id = sql.Column(sql.String(255),
  110                            primary_key=True)
  111 
  112     @classmethod
  113     def from_dict(cls, dictionary):
  114         new_dictionary = dictionary.copy()
  115         return cls(**new_dictionary)
  116 
  117     def to_dict(self):
  118         """Return a dictionary with model's attributes."""
  119         d = dict()
  120         for attr in self.__class__.attributes:
  121             d[attr] = getattr(self, attr)
  122         return d
  123 
  124 
  125 class MappingModel(sql.ModelBase, sql.ModelDictMixin):
  126     __tablename__ = 'mapping'
  127     attributes = ['id', 'rules']
  128 
  129     id = sql.Column(sql.String(64), primary_key=True)
  130     rules = sql.Column(sql.JsonBlob(), nullable=False)
  131 
  132     @classmethod
  133     def from_dict(cls, dictionary):
  134         new_dictionary = dictionary.copy()
  135         new_dictionary['rules'] = jsonutils.dumps(new_dictionary['rules'])
  136         return cls(**new_dictionary)
  137 
  138     def to_dict(self):
  139         """Return a dictionary with model's attributes."""
  140         d = dict()
  141         for attr in self.__class__.attributes:
  142             d[attr] = getattr(self, attr)
  143         d['rules'] = jsonutils.loads(d['rules'])
  144         return d
  145 
  146 
  147 class ServiceProviderModel(sql.ModelBase, sql.ModelDictMixin):
  148     __tablename__ = 'service_provider'
  149     attributes = ['auth_url', 'id', 'enabled', 'description',
  150                   'relay_state_prefix', 'sp_url']
  151     mutable_attributes = frozenset(['auth_url', 'description', 'enabled',
  152                                     'relay_state_prefix', 'sp_url'])
  153 
  154     id = sql.Column(sql.String(64), primary_key=True)
  155     enabled = sql.Column(sql.Boolean, nullable=False)
  156     description = sql.Column(sql.Text(), nullable=True)
  157     auth_url = sql.Column(sql.String(256), nullable=False)
  158     sp_url = sql.Column(sql.String(256), nullable=False)
  159     relay_state_prefix = sql.Column(sql.String(256), nullable=False)
  160 
  161     @classmethod
  162     def from_dict(cls, dictionary):
  163         new_dictionary = dictionary.copy()
  164         return cls(**new_dictionary)
  165 
  166     def to_dict(self):
  167         """Return a dictionary with model's attributes."""
  168         d = dict()
  169         for attr in self.__class__.attributes:
  170             d[attr] = getattr(self, attr)
  171         return d
  172 
  173 
  174 class Federation(base.FederationDriverBase):
  175 
  176     _CONFLICT_LOG_MSG = 'Conflict %(conflict_type)s: %(details)s'
  177 
  178     def _handle_idp_conflict(self, e):
  179         conflict_type = 'identity_provider'
  180         details = str(e)
  181         LOG.debug(self._CONFLICT_LOG_MSG, {'conflict_type': conflict_type,
  182                                            'details': details})
  183         if 'remote_id' in details:
  184             msg = _('Duplicate remote ID: %s')
  185         else:
  186             msg = _('Duplicate entry: %s')
  187         msg = msg % e.value
  188         raise exception.Conflict(type=conflict_type, details=msg)
  189 
  190     # Identity Provider CRUD
  191     def create_idp(self, idp_id, idp):
  192         idp['id'] = idp_id
  193         try:
  194             with sql.session_for_write() as session:
  195                 idp_ref = IdentityProviderModel.from_dict(idp)
  196                 session.add(idp_ref)
  197                 return idp_ref.to_dict()
  198         except sql.DBDuplicateEntry as e:
  199             self._handle_idp_conflict(e)
  200 
  201     def delete_idp(self, idp_id):
  202         with sql.session_for_write() as session:
  203             self._delete_assigned_protocols(session, idp_id)
  204             idp_ref = self._get_idp(session, idp_id)
  205             session.delete(idp_ref)
  206 
  207     def _get_idp(self, session, idp_id):
  208         idp_ref = session.query(IdentityProviderModel).get(idp_id)
  209         if not idp_ref:
  210             raise exception.IdentityProviderNotFound(idp_id=idp_id)
  211         return idp_ref
  212 
  213     def _get_idp_from_remote_id(self, session, remote_id):
  214         q = session.query(IdPRemoteIdsModel)
  215         q = q.filter_by(remote_id=remote_id)
  216         try:
  217             return q.one()
  218         except sql.NotFound:
  219             raise exception.IdentityProviderNotFound(idp_id=remote_id)
  220 
  221     def list_idps(self, hints=None):
  222         with sql.session_for_read() as session:
  223             query = session.query(IdentityProviderModel)
  224             idps = sql.filter_limit_query(IdentityProviderModel, query, hints)
  225             idps_list = [idp.to_dict() for idp in idps]
  226             return idps_list
  227 
  228     def get_idp(self, idp_id):
  229         with sql.session_for_read() as session:
  230             idp_ref = self._get_idp(session, idp_id)
  231             return idp_ref.to_dict()
  232 
  233     def get_idp_from_remote_id(self, remote_id):
  234         with sql.session_for_read() as session:
  235             ref = self._get_idp_from_remote_id(session, remote_id)
  236             return ref.to_dict()
  237 
  238     def update_idp(self, idp_id, idp):
  239         try:
  240             with sql.session_for_write() as session:
  241                 idp_ref = self._get_idp(session, idp_id)
  242                 old_idp = idp_ref.to_dict()
  243                 old_idp.update(idp)
  244                 new_idp = IdentityProviderModel.from_dict(old_idp)
  245                 for attr in IdentityProviderModel.mutable_attributes:
  246                     setattr(idp_ref, attr, getattr(new_idp, attr))
  247                 return idp_ref.to_dict()
  248         except sql.DBDuplicateEntry as e:
  249             self._handle_idp_conflict(e)
  250 
  251     # Protocol CRUD
  252     def _get_protocol(self, session, idp_id, protocol_id):
  253         q = session.query(FederationProtocolModel)
  254         q = q.filter_by(id=protocol_id, idp_id=idp_id)
  255         try:
  256             return q.one()
  257         except sql.NotFound:
  258             kwargs = {'protocol_id': protocol_id,
  259                       'idp_id': idp_id}
  260             raise exception.FederatedProtocolNotFound(**kwargs)
  261 
  262     @sql.handle_conflicts(conflict_type='federation_protocol')
  263     def create_protocol(self, idp_id, protocol_id, protocol):
  264         protocol['id'] = protocol_id
  265         protocol['idp_id'] = idp_id
  266         with sql.session_for_write() as session:
  267             self._get_idp(session, idp_id)
  268             protocol_ref = FederationProtocolModel.from_dict(protocol)
  269             session.add(protocol_ref)
  270             return protocol_ref.to_dict()
  271 
  272     def update_protocol(self, idp_id, protocol_id, protocol):
  273         with sql.session_for_write() as session:
  274             proto_ref = self._get_protocol(session, idp_id, protocol_id)
  275             old_proto = proto_ref.to_dict()
  276             old_proto.update(protocol)
  277             new_proto = FederationProtocolModel.from_dict(old_proto)
  278             for attr in FederationProtocolModel.mutable_attributes:
  279                 setattr(proto_ref, attr, getattr(new_proto, attr))
  280             return proto_ref.to_dict()
  281 
  282     def get_protocol(self, idp_id, protocol_id):
  283         with sql.session_for_read() as session:
  284             protocol_ref = self._get_protocol(session, idp_id, protocol_id)
  285             return protocol_ref.to_dict()
  286 
  287     def list_protocols(self, idp_id):
  288         with sql.session_for_read() as session:
  289             q = session.query(FederationProtocolModel)
  290             q = q.filter_by(idp_id=idp_id)
  291             protocols = [protocol.to_dict() for protocol in q]
  292             return protocols
  293 
  294     def delete_protocol(self, idp_id, protocol_id):
  295         with sql.session_for_write() as session:
  296             key_ref = self._get_protocol(session, idp_id, protocol_id)
  297             session.delete(key_ref)
  298 
  299     def _delete_assigned_protocols(self, session, idp_id):
  300         query = session.query(FederationProtocolModel)
  301         query = query.filter_by(idp_id=idp_id)
  302         query.delete()
  303 
  304     # Mapping CRUD
  305     def _get_mapping(self, session, mapping_id):
  306         mapping_ref = session.query(MappingModel).get(mapping_id)
  307         if not mapping_ref:
  308             raise exception.MappingNotFound(mapping_id=mapping_id)
  309         return mapping_ref
  310 
  311     @sql.handle_conflicts(conflict_type='mapping')
  312     def create_mapping(self, mapping_id, mapping):
  313         ref = {}
  314         ref['id'] = mapping_id
  315         ref['rules'] = mapping.get('rules')
  316         with sql.session_for_write() as session:
  317             mapping_ref = MappingModel.from_dict(ref)
  318             session.add(mapping_ref)
  319             return mapping_ref.to_dict()
  320 
  321     def delete_mapping(self, mapping_id):
  322         with sql.session_for_write() as session:
  323             mapping_ref = self._get_mapping(session, mapping_id)
  324             session.delete(mapping_ref)
  325 
  326     def list_mappings(self):
  327         with sql.session_for_read() as session:
  328             mappings = session.query(MappingModel)
  329             return [x.to_dict() for x in mappings]
  330 
  331     def get_mapping(self, mapping_id):
  332         with sql.session_for_read() as session:
  333             mapping_ref = self._get_mapping(session, mapping_id)
  334             return mapping_ref.to_dict()
  335 
  336     @sql.handle_conflicts(conflict_type='mapping')
  337     def update_mapping(self, mapping_id, mapping):
  338         ref = {}
  339         ref['id'] = mapping_id
  340         ref['rules'] = mapping.get('rules')
  341         with sql.session_for_write() as session:
  342             mapping_ref = self._get_mapping(session, mapping_id)
  343             old_mapping = mapping_ref.to_dict()
  344             old_mapping.update(ref)
  345             new_mapping = MappingModel.from_dict(old_mapping)
  346             for attr in MappingModel.attributes:
  347                 setattr(mapping_ref, attr, getattr(new_mapping, attr))
  348             return mapping_ref.to_dict()
  349 
  350     def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id):
  351         with sql.session_for_read() as session:
  352             protocol_ref = self._get_protocol(session, idp_id, protocol_id)
  353             mapping_id = protocol_ref.mapping_id
  354             mapping_ref = self._get_mapping(session, mapping_id)
  355             return mapping_ref.to_dict()
  356 
  357     # Service Provider CRUD
  358     @sql.handle_conflicts(conflict_type='service_provider')
  359     def create_sp(self, sp_id, sp):
  360         sp['id'] = sp_id
  361         with sql.session_for_write() as session:
  362             sp_ref = ServiceProviderModel.from_dict(sp)
  363             session.add(sp_ref)
  364             return sp_ref.to_dict()
  365 
  366     def delete_sp(self, sp_id):
  367         with sql.session_for_write() as session:
  368             sp_ref = self._get_sp(session, sp_id)
  369             session.delete(sp_ref)
  370 
  371     def _get_sp(self, session, sp_id):
  372         sp_ref = session.query(ServiceProviderModel).get(sp_id)
  373         if not sp_ref:
  374             raise exception.ServiceProviderNotFound(sp_id=sp_id)
  375         return sp_ref
  376 
  377     def list_sps(self, hints=None):
  378         with sql.session_for_read() as session:
  379             query = session.query(ServiceProviderModel)
  380             sps = sql.filter_limit_query(ServiceProviderModel, query, hints)
  381             sps_list = [sp.to_dict() for sp in sps]
  382             return sps_list
  383 
  384     def get_sp(self, sp_id):
  385         with sql.session_for_read() as session:
  386             sp_ref = self._get_sp(session, sp_id)
  387             return sp_ref.to_dict()
  388 
  389     def update_sp(self, sp_id, sp):
  390         with sql.session_for_write() as session:
  391             sp_ref = self._get_sp(session, sp_id)
  392             old_sp = sp_ref.to_dict()
  393             old_sp.update(sp)
  394             new_sp = ServiceProviderModel.from_dict(old_sp)
  395             for attr in ServiceProviderModel.mutable_attributes:
  396                 setattr(sp_ref, attr, getattr(new_sp, attr))
  397             return sp_ref.to_dict()
  398 
  399     def get_enabled_service_providers(self):
  400         with sql.session_for_read() as session:
  401             service_providers = session.query(ServiceProviderModel)
  402             service_providers = service_providers.filter_by(enabled=True)
  403             return service_providers