Source code for vineyard.core.driver

#! /usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2020-2023 Alibaba Group Holding Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import contextlib
import copy
import functools
import threading

from sortedcontainers import SortedDict

from vineyard.core.utils import find_most_precise_match


class DriverContext:
    """Registry of driver methods keyed by type-name prefix.

    ``_factory`` maps a type-name prefix to a ``{method_name: function}``
    dict. Resolving an object against a type name binds every function
    registered for the matching prefix onto that object as an attribute.
    """

    def __init__(self):
        # NOTE: the mapping must be created without positional arguments.
        # ``SortedDict`` treats a callable first argument as its *key
        # function*, so ``SortedDict(dict)`` would try to evaluate
        # ``dict(key)`` for every inserted type name and fail.
        self._factory = SortedDict()

    def __str__(self) -> str:
        return str(self._factory)

    def register(self, typename_prefix, meth, func):
        """Register ``func`` as method ``meth`` for ``typename_prefix``.

        A later registration for the same ``(typename_prefix, meth)`` pair
        replaces the earlier one.
        """
        self._factory.setdefault(typename_prefix, dict())[meth] = func

    def resolve(self, obj, typename):
        """Attach the drivers registered for ``typename`` to ``obj``.

        The lookup is delegated to ``find_most_precise_match``; when it
        yields a truthy prefix, each registered function is bound to
        ``obj`` with ``functools.partial`` and set as an attribute named
        after the method. ``obj`` is returned in every case.
        """
        prefix, methods = find_most_precise_match(typename, self._factory)
        if prefix:
            for meth, func in methods.items():
                # setattr is not expected to fail here: per the original
                # note, `obj` has already been wrapped during resolving.
                setattr(obj, meth, functools.partial(func, obj))
        return obj

    def __call__(self, obj, typename):
        """Shorthand for :meth:`resolve`."""
        return self.resolve(obj, typename)

    def extend(self, drivers=None):
        """Return a new context holding this one's drivers plus ``drivers``.

        Parameters:
            drivers: optional mapping of ``{typename_prefix: {meth: func}}``
                merged on top of the copied registrations.

        The per-prefix method dicts are shallow-copied, so registering
        into the returned context does not modify this one.
        """
        driver = DriverContext()
        driver._factory.update(
            (k, copy.copy(v)) for k, v in self._factory.items()
        )
        if drivers:
            for ty, methods in drivers.items():
                # Look the prefix up in the *new* context, which is the
                # one being populated.
                driver._factory.setdefault(ty, dict()).update(methods)
        return driver
# Module-wide default driver context, created once at import time.
default_driver_context = DriverContext()

# Thread-local holder for the currently active driver context. Attributes of
# a `threading.local` are per-thread, so the assignment below is only visible
# in the thread that executes this module-level code; in any other thread the
# `default_driver` attribute is initially absent.
_driver_context_local = threading.local()
_driver_context_local.default_driver = default_driver_context
def get_current_drivers():
    '''Obtain current driver context.

    Returns the ``default_driver`` stored in the thread-local holder. When
    the calling thread has none (or it is falsy), a new context produced by
    ``default_driver_context.extend()`` is returned instead; that fallback
    is not stored back into the thread-local holder.
    '''
    try:
        active = _driver_context_local.default_driver
    except AttributeError:
        # Nothing has been set for this thread.
        active = None
    if active:
        return active
    return default_driver_context.extend()
@contextlib.contextmanager
def driver_context(drivers=None, base=None):
    """Open a new context for registering drivers, without polluting the
    outside global environment.

    Parameters:
        drivers: optional mapping of extra drivers passed to ``extend()``;
            an empty dict is used when omitted or falsy.
        base: optional context to extend; the current context is used when
            omitted or falsy.

    Yields the newly created context, which is installed as the calling
    thread's current context for the duration of the ``with`` block. The
    context that was current on entry is restored on exit, even when the
    body raises.

    See Also:
        builder_context
        resolver_context
    """
    previous = get_current_drivers()
    try:
        extras = drivers if drivers else dict()
        parent = base if base else previous
        scoped = parent.extend(extras)
        _driver_context_local.default_driver = scoped
        yield scoped
    finally:
        _driver_context_local.default_driver = previous
def register_builtin_drivers(ctx):
    """Register the builtin drivers into ``ctx``.

    Currently a placeholder: it only checks that ``ctx`` is a
    :class:`DriverContext`.
    """
    assert isinstance(ctx, DriverContext)

    # TODO
    # there's no builtin drivers yet.


def registerize(func):
    """Registerize a method, add a `_factory` attribute and a `register`
    interface to a method.

    multiple-level register is automatically supported, users can

    >>> open.register(local_io_adaptor)
    >>> open.register(oss_io_adaptor)

    OR

    >>> open.register('file', local_io_adaptor)
    >>> open.register('odps', odps_io_adaptor)

    OR

    >>> open.register('file', 'csv', local_csv_reader)
    >>> open.register('file', 'tsv', local_tsv_reader)

    The returned wrapper forwards every call to ``func`` unchanged.
    ``_factory`` starts as ``None`` and becomes either a flat list (the
    single-argument form) or a nested dict whose leaves are lists (the
    keyed forms). The two shapes, and different key depths, cannot be
    mixed on one wrapper: doing so raises ``RuntimeError``.
    """

    @functools.wraps(func)
    def wrap(*args, **kwargs):
        return func(*args, **kwargs)

    setattr(wrap, '_factory', None)

    def register(*args):
        """Add an implementation; the last argument is the implementation
        and any preceding arguments are the keys it is filed under.

        Raises:
            TypeError: if called without arguments.
            RuntimeError: if the call shape conflicts with earlier
                registrations.
        """
        if not args:
            # Reject before touching `_factory`, so a bad call cannot turn
            # an unused registry into a dict and break later flat
            # registrations.
            raise TypeError('register() requires at least one argument')

        if len(args) == 1:
            if wrap._factory is None:
                wrap._factory = []
            if not isinstance(wrap._factory, list):
                raise RuntimeError(
                    'Invalid arguments: inconsistent with existing registerations'
                )
            wrap._factory.append(args[0])
            return

        if wrap._factory is None:
            wrap._factory = {}
        if not isinstance(wrap._factory, dict):
            raise RuntimeError(
                'Invalid arguments: inconsistent with existing registerations'
            )

        *path, leaf, impl = args
        root = wrap._factory
        for arg in path:
            root = root.setdefault(arg, dict())
            # An earlier, shallower registration left a list at this key.
            if not isinstance(root, dict):
                raise RuntimeError(
                    'Invalid arguments: inconsistent with existing registerations'
                )
        handlers = root.setdefault(leaf, list())
        # An earlier, deeper registration left a dict at this key.
        if not isinstance(handlers, list):
            raise RuntimeError(
                'Invalid arguments: inconsistent with existing registerations'
            )
        handlers.append(impl)

    setattr(wrap, 'register', register)

    return wrap


__all__ = [
    'default_driver_context',
    'register_builtin_drivers',
    'driver_context',
    'get_current_drivers',
    'registerize',
]