#! /usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2020-2023 Alibaba Group Holding Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import contextlib
import copy
import functools
import threading
from sortedcontainers import SortedDict
from vineyard.core.utils import find_most_precise_match
class DriverContext:
    """Registry of driver functions keyed by type-name prefix.

    ``_factory`` maps a type-name prefix to a ``{method_name: function}``
    dict. :meth:`resolve` looks up the entry matching a type name and attaches
    each registered function to the given object as an attribute.
    """

    def __init__(self):
        # ``SortedDict`` interprets a callable first positional argument as
        # the sort-key function, so ``SortedDict(dict)`` would evaluate
        # ``dict(prefix)`` on every insertion and fail for string prefixes.
        # An argument-less ``SortedDict`` orders prefixes by their natural
        # ordering instead.
        self._factory = SortedDict()

    def __str__(self) -> str:
        return str(self._factory)

    def register(self, typename_prefix, meth, func):
        """Register ``func`` under the method name ``meth`` for
        ``typename_prefix``.

        A later registration with the same prefix and method name replaces
        the earlier one.
        """
        self._factory.setdefault(typename_prefix, {})[meth] = func

    def resolve(self, obj, typename):
        """Attach the drivers registered for ``typename`` to ``obj``.

        The prefix is selected by ``find_most_precise_match`` (presumably the
        most specific registered prefix of ``typename`` -- confirm against
        ``vineyard.core.utils``). For every ``(meth, func)`` registered under
        that prefix, ``obj.meth`` is set to ``functools.partial(func, obj)``,
        so that calling ``obj.meth(...)`` invokes ``func(obj, ...)``.

        Returns:
            ``obj`` itself, with the attributes set when a prefix matched.
        """
        prefix, methods = find_most_precise_match(typename, self._factory)
        if prefix:
            for meth, func in methods.items():
                # Setting the attribute is not expected to fail here: per the
                # original note, ``obj`` has already been wrapped during
                # resolving (assumption carried over -- verify with callers).
                setattr(obj, meth, functools.partial(func, obj))
        return obj

    def __call__(self, obj, typename):
        """Shorthand for :meth:`resolve`."""
        return self.resolve(obj, typename)

    def extend(self, drivers=None):
        """Return a new context holding this context's drivers plus
        ``drivers``.

        Parameters:
            drivers: optional mapping ``{typename_prefix: {meth: func}}``
                merged on top of the copied registrations.

        The per-prefix method dicts are shallow-copied, so registering on the
        returned context does not modify this one.
        """
        driver = DriverContext()
        driver._factory.update(
            (prefix, copy.copy(methods)) for prefix, methods in self._factory.items()
        )
        if drivers:
            for ty, methods in drivers.items():
                driver._factory.setdefault(ty, {}).update(methods)
        return driver
# Module-wide default driver context; ``get_current_drivers`` falls back to a
# copy of it (via ``extend()``) when no context is set for the calling thread.
default_driver_context = DriverContext()
# Holder for the currently active driver context. ``threading.local``
# attributes are per-thread, so the assignment below is only visible in the
# thread that executes this module body.
_driver_context_local = threading.local()
_driver_context_local.default_driver = default_driver_context
def get_current_drivers():
    """Return the driver context that is active for the calling thread.

    When the thread-local holder has no ``default_driver`` set (or it is
    falsy), a fresh copy of ``default_driver_context`` is returned instead;
    that copy is not stored back into the holder.
    """
    active = getattr(_driver_context_local, 'default_driver', None)
    return active if active else default_driver_context.extend()
@contextlib.contextmanager
def driver_context(drivers=None, base=None):
    """Open a new context for registering drivers, without polluting the
    outside global environment.

    Parameters:
        drivers: optional mapping ``{typename_prefix: {meth: func}}`` added
            on top of ``base``.
        base: context to extend; defaults to the currently active context.

    Yields:
        The extended context, which is the active context for the duration
        of the ``with`` block. The previously active context is reinstated
        on exit, including when the block raises.

    See Also:
        builder_context
        resolver_context
    """
    previous = get_current_drivers()
    try:
        if not drivers:
            drivers = dict()
        if not base:
            base = previous
        scoped = base.extend(drivers)
        _driver_context_local.default_driver = scoped
        yield scoped
    finally:
        _driver_context_local.default_driver = previous
def register_builtin_drivers(ctx):
    """Register the builtin drivers into ``ctx``.

    Currently this only asserts that ``ctx`` is a ``DriverContext``.
    """
    # TODO: there are no builtin drivers yet.
    is_driver_context = isinstance(ctx, DriverContext)
    assert is_driver_context
def registerize(func):
    """Registerize a method: add a ``_factory`` attribute and a ``register``
    interface to it.

    The returned wrapper calls ``func`` unchanged. ``_factory`` starts as
    ``None`` and is shaped by the first successful ``register`` call.

    Multiple-level registration is supported; users can

    >>> open.register(local_io_adaptor)
    >>> open.register(oss_io_adaptor)

    OR

    >>> open.register('file', local_io_adaptor)
    >>> open.register('odps', odps_io_adaptor)

    OR

    >>> open.register('file', 'csv', local_csv_reader)
    >>> open.register('file', 'tsv', local_tsv_reader)

    With a single argument ``_factory`` is a list of implementations. With
    keys, ``_factory`` is a nested dict whose innermost values are lists of
    implementations.
    """

    @functools.wraps(func)
    def wrap(*args, **kwargs):
        return func(*args, **kwargs)

    wrap._factory = None

    def register(*args):
        """Register ``args[-1]`` under the key path ``args[:-1]``.

        Raises:
            RuntimeError: if no argument is given, or if the key path is
                inconsistent with what has been registered before (e.g.
                mixing keyed and un-keyed registrations, or registering
                below a key that already holds implementations).
        """
        inconsistent = 'Invalid arguments: inconsistent with existing registerations'
        if not args:
            # Reject before touching ``_factory``; otherwise a stray empty
            # call would pin the registry to the keyed (dict) shape.
            raise RuntimeError('Invalid arguments: nothing to register')

        *keys, impl = args
        expected = dict if keys else list
        factory = wrap._factory
        if factory is None:
            # Only published to ``wrap._factory`` once registration succeeds.
            factory = expected()
        elif not isinstance(factory, expected):
            raise RuntimeError(inconsistent)

        if not keys:
            factory.append(impl)
            wrap._factory = factory
            return

        *parents, leaf = keys
        node = factory
        for key in parents:
            node = node.setdefault(key, {})
            if not isinstance(node, dict):
                # ``key`` already holds a list of implementations.
                raise RuntimeError(inconsistent)
        bucket = node.setdefault(leaf, [])
        if not isinstance(bucket, list):
            # ``leaf`` already holds a deeper level of keys.
            raise RuntimeError(inconsistent)
        bucket.append(impl)
        wrap._factory = factory

    wrap.register = register
    return wrap
# Names exported by ``from <this module> import *``. Note that the
# ``DriverContext`` class itself is not part of this list.
__all__ = [
'default_driver_context',
'register_builtin_drivers',
'driver_context',
'get_current_drivers',
'registerize',
]