Commit b61c974b authored by Alexei Kornienko's avatar Alexei Kornienko
Browse files

Merge branch 'vitalii/swagger_generator' into 'master'

Vitalii/swagger generator

See merge request !32
parents b3281516 4166f922
import importlib
import importlib.util
import os
import sys
from copy import deepcopy
from datetime import datetime
from uuid import uuid4
from bson import ObjectId
from schematics import types
from schematics.contrib.mongo import ObjectIdType
class SwaggerSysFieldsError(Exception):
pass
def get_class_from_path(path: str):
module, class_name = path.rsplit('.', 1)
return getattr(importlib.import_module(module), class_name)
class SchemaTransformer:
def __init__(self, schema: dict):
self.schema = schema
def _get_model(self, schema, value):
return schema[value] if value in schema else self.schema[value]
def expand_schema(self, schema: dict) -> dict:
expanded_schema = {}
for model_name, model in schema.copy().items():
expanded_schema[model_name] = self.expand_schema_model(expanded_schema, model)
return expanded_schema
def _split_model_name(self, ref):
return ref.split('/')[-1]
def expand_schema_model(self, schema: dict, model: dict) -> dict:
if '$ref' in model:
return deepcopy(self._get_model(schema, self._split_model_name(model['$ref'])))
if 'allOf' in model:
schema_model = self._merge_model(schema, model)
elif 'anyOf' in model:
schema_model = {'anyOf': {
self._split_model_name(ref['$ref']): self.expand_schema_model(schema, ref)
for ref in model['anyOf']
}}
else:
schema_model = deepcopy(model)
if 'properties' in schema_model:
schema_model['properties'] = self._expand_props(schema, schema_model['properties'])
if 'items' in schema_model:
schema_model['items'] = self.expand_schema_model(schema, schema_model['items'])
if not schema_model:
raise ValueError(f'Error on expand schema, with model {model}')
return schema_model
def _merge_model(self, schema: dict, model: dict) -> dict:
all_of_obj = model.pop('allOf')
base_model_name = all_of_obj[0]['$ref'].split('/')[-1]
base_model = self._get_model(schema, base_model_name)
merged_model = deepcopy(base_model)
merged_model.update(model)
if len(all_of_obj) > 1:
model_all_of_obj = all_of_obj[1]
if 'properties' in model_all_of_obj:
merged_model['properties'].update(model_all_of_obj['properties'])
if 'required' in model_all_of_obj:
merged_model['required'] = model_all_of_obj['required']
return merged_model
def _expand_props(self, schema: dict, props: dict) -> dict:
return {k: self.expand_schema_model(schema, v) for k, v in props.items()}
class BaseSchemaBuilder:
_schemas = {}
models_dir = 'models'
models_module = 'schematics.models'
base_model = f'{models_module}.Model'
def __init__(self, schema, module):
swagger_schema = self._transform_schemas(schema)
self._schemas.update(swagger_schema)
self.swagger_schema = deepcopy(swagger_schema)
self.module = self._get_module(module)
self.known_props = {
'enum': 'choices',
'maxLength': 'max_length',
'minLength': 'min_length',
'pattern': 'regex',
'is_required': 'required',
'minimum': 'min_value',
'maximum': 'max_value',
'default': 'default',
'minItems': 'min_size',
'maxItems': 'max_size',
'x-serialize_when_none': 'serialize_when_none',
}
self.custom_default = {
'hex': lambda: uuid4().hex,
'now': datetime.now,
'object_id': ObjectId
}
self.data_types = {
'boolean': lambda x: types.BooleanType(**self._map_params(x)),
'int32': lambda x: types.IntType(**self._map_params(x)),
'int64': lambda x: types.LongType(**self._map_params(x)),
'float': lambda x: types.FloatType(**self._map_params(x)),
'double': lambda x: types.DecimalType(**self._map_params(x)),
'string': lambda x: self._create_string_type(x),
'uuid': lambda x: types.UUIDType(**self._map_params(x)),
'md5': lambda x: types.MD5Type(**self._map_params(x)),
'sha1': lambda x: types.SHA1Type(**self._map_params(x)),
'date': lambda x: types.DateType(**self._map_params(x)),
'date-time': lambda x: types.UTCDateTimeType(**self._map_params(x)),
'email': lambda x: types.EmailType(**self._map_params(x)),
'uri': lambda x: types.URLType(**self._map_params(x)),
'array': lambda x: self._create_array_type(types.ListType, x),
'object': self._build_model,
'polymodel': self._build_poly_model,
'multidict': lambda x: types.MultiDictType(**self._map_params(x)),
'decimal-float': lambda x: types.DecimalIntegerType(**self._map_params(x)),
'object-id': lambda x: ObjectIdType(**self._map_params(x)),
'timer': lambda x: types.TimerType(**self._map_params(x)),
}
def _build_poly_model(self, prop_data):
models = []
model_props = self._map_params(prop_data)
for name, model in prop_data['items']['anyOf'].items():
models.append(self._create_schema_model_class(name.split('.')[-1], model))
poly_model = types.PolyModelType(
model_spec=models,
claim_function=get_class_from_path(prop_data['x-claim-func']),
required=model_props.get('is_required', False)
)
if 'type' in prop_data and prop_data['type'] == 'array':
poly_model = types.ListType(poly_model, **model_props)
return poly_model
def _transform_schemas(self, schema: dict) -> dict:
return SchemaTransformer(self._schemas).expand_schema(schema)
@classmethod
def _get_module(cls, source_module: str):
import_path = f'{cls.models_module}.{source_module}'
try:
module = importlib.import_module(import_path)
except ModuleNotFoundError:
module_location = os.path.abspath(f'{cls.models_dir}/{source_module}.py')
spec = importlib.util.spec_from_file_location(import_path, module_location)
module = importlib.util.module_from_spec(spec)
sys.modules[import_path] = module
return module
def _map_params(self, prop_data: dict) -> dict:
params = {}
for name, value in prop_data.items():
if name in self.known_props:
params[self.known_props[name]] = value
if 'x-default' in prop_data:
params['default'] = self.custom_default[prop_data['x-default']]
return params
def _build_field(self, prop_data: dict):
field_type_key = None
for key in ('x-format', 'format', 'type'):
if key in prop_data:
field_type_key = prop_data[key]
break
return self.data_types[field_type_key](prop_data)
def _build_model(self, prop_data: dict):
model = self._create_schema_model_class(prop_data['model_name'], prop_data)
params = self._map_params(prop_data)
params.update({'model_spec': model, 'required': prop_data.get('is_required', False)})
return types.ModelType(**params)
def _create_string_type(self, prop_data: dict):
params = self._map_params(prop_data)
return types.StringType(**params)
def _create_array_type(self, field_type, prop_data):
items_data = prop_data.pop('items')
items_data['model_name'] = prop_data['model_name']
params = self._map_params(prop_data)
params['field'] = self._build_field(items_data)
return field_type(**params)
def _create_schema_model_class(self, model_name: str, model_info: dict):
parent_class = get_class_from_path(model_info.get('x-baseClass', self.base_model))
props = model_info.get('properties', {})
required_fields = model_info.get('required', set())
fields = {}
for prop_name, prop_info in props.items():
if prop_name in required_fields:
prop_info['is_required'] = True
prop_info['model_name'] = f'{model_name}{prop_name.capitalize()}'
fields[prop_name] = self._build_field(prop_info)
model = type(model_name, (parent_class,), fields)
sys_fields = getattr(model, f'_{parent_class.__name__}__sys_fields', set())
missed_fields = sys_fields - set(dir(model))
if missed_fields:
raise SwaggerSysFieldsError(f'SysFields: {missed_fields} are not defined')
setattr(self.module, model_name, model)
return model
def generate_schema_from_swagger(self):
for model_name, model_info in self.swagger_schema.items():
model_name = model_name.split('.')[-1]
self._create_schema_model_class(model_name, model_info)
import pytest
from schematics import types
from schematics.exceptions import DataError
from schematics.swagger_generator import BaseSchemaBuilder
def _get_test_cases_models():
return {
'_schema': {
'models.TestValue': {
'title': 'TestValue',
'x-baseClass': 'schematics.models.Model',
'type': 'object',
'description': 'TestValue model',
'properties': {
'currency': {
'type': 'string',
'enum': ['EUR', 'UAH', 'USD'],
'x-legalNameUa': 'Валюта',
'x-legalNameEn': 'Currency',
},
'amount': {
'type': 'number',
'format': 'float',
'x-format': 'decimal-float',
'x-legalNameUa': 'Сума',
'x-legalNameEn': 'Amount',
}
},
'required': ['amount', 'currency']
},
'models.EuroPayment': {
'title': 'EuroPayment',
'x-baseClass': 'schematics.models.Model',
'type': 'object',
'description': 'EuroPayment model',
'properties': {
'value': {
'allOf': [{'$ref': '#/components/schemas/models.TestValue'}, {'type': 'object', 'properties': {
'currency': {
'type': 'string',
'enum': ['EUR', ],
'x-legalNameUa': 'Евро',
'x-legalNameEn': 'Euro',
}
}, 'required': ['currency', ]}]
}
},
'required': ['value', ]
},
'models.Bank': {
'title': 'Bank',
'x-baseClass': 'schematics.models.Model',
'type': 'object',
'description': 'Bank model',
'properties': {
'name': {
'type': 'string',
'maxLength': 100,
'x-legalNameUa': 'Назва',
'x-legalNameEn': 'Name',
},
'payments': {
'type': 'array',
'default': [],
'items': {
'allOf': [{'$ref': '#/components/schemas/models.EuroPayment'},
{'type': 'object', 'properties': {
'currency': {
'type': 'string',
'enum': ['EUR', 'USD']
}
}}]
}
},
'other_payments': {
'type': 'array',
'default': [],
'items': {
'allOf': [{'$ref': '#/components/schemas/models.EuroPayment'},
{'type': 'object', 'properties': {
'currency': {
'type': 'string',
'enum': ['UAN', 'eurocent']
}
}}]
}
}
},
'required': ['name', 'payments']
},
'models.CustomValueFieldOverriding': {
'title': 'models.CustomValue',
'allOf': [
{'$ref': '#/components/schemas/models.TestValue'},
{
'type': 'object',
'properties': {
'amount': {
'type': 'number',
'format': 'float',
'x-format': 'decimal-float',
'x-legalNameUa': 'Сума',
'x-legalNameEn': 'Amount',
}
}
}
]
},
'models.CustomValueRequirementsOverriding': {
'title': 'models.CustomValue',
'allOf': [
{'$ref': '#/components/schemas/models.TestValue'},
{
'type': 'object',
'properties': {
'amount': {
'type': 'number',
'format': 'float',
'x-format': 'decimal-float',
'x-legalNameUa': 'Сума',
'x-legalNameEn': 'Amount',
}
},
'required': ['currency']
}
]
}
}
}
def _get_required_test_cases_models():
return {
'_schema': {
'models.A': {
'title': 'models.A',
'properties': {
'field_1': {
'type': 'string'
},
'field_2': {
'type': 'string'
}
},
'required': ['field_1']
},
'models.B': {
'title': 'models.B',
'allOf': [
{'$ref': '#/components/schemas/models.A'},
{
'type': 'object',
'properties': {
'field_1': {
'type': 'boolean'
}
},
'required': ['field_1', 'field_2']
}
],
},
'models.C': {
'title': 'models.C',
'allOf': [
{'$ref': '#/components/schemas/models.A'},
{
'type': 'object',
'properties': {
'field_2': {
'type': 'boolean'
}
},
'required': ['field_2']
}
],
},
'models.D': {
'title': 'models.D',
'allOf': [
{'$ref': '#/components/schemas/models.A'},
{
'type': 'object',
'required': ['field_2']
}
],
},
'models.E': {
'title': 'models.E',
'allOf': [
{'$ref': '#/components/schemas/models.A'},
{
'type': 'object',
'required': []
}
],
}
}
}
def test_fields_model_inheritance():
test_class_schema = _get_test_cases_models()['_schema']
schema = BaseSchemaBuilder(test_class_schema, 'models')
schema.generate_schema_from_swagger()
cls = getattr(schema.module, 'EuroPayment')
assert isinstance(cls.fields['value'].fields['currency'], types.StringType)
expanded_model = {
'title': 'EuroPayment',
'x-baseClass': 'schematics.models.Model',
'type': 'object',
'description': 'EuroPayment model',
'properties': {
'value': {
'title': 'TestValue',
'x-baseClass': 'schematics.models.Model',
'type': 'object',
'description': 'TestValue model',
'properties': {
'currency': {
'type': 'string',
'enum': ['EUR'],
'x-legalNameUa': 'Евро',
'x-legalNameEn': 'Euro',
},
'amount': {
'type': 'number',
'format': 'float',
'x-format': 'decimal-float',
'x-legalNameUa': 'Сума',
'x-legalNameEn': 'Amount',
}
},
'required': ['currency']
}
},
'required': ['value']
}
assert schema._schemas['models.EuroPayment'] == expanded_model
assert cls.fields['value'].fields['currency'].choices == ['EUR']
assert cls.fields['value'].fields['currency'].required is True
assert cls.fields['value'].required is True
def test_nested_fields_model_inheritance():
test_class_schema = _get_test_cases_models()['_schema']
schema = BaseSchemaBuilder(test_class_schema, 'models')
schema.generate_schema_from_swagger()
cls = getattr(schema.module, 'Bank')
expanded_model = {
'title': 'Bank',
'x-baseClass': 'schematics.models.Model',
'type': 'object',
'description': 'Bank model',
'properties': {
'name': {
'type': 'string',
'maxLength': 100,
'x-legalNameUa': 'Назва',
'x-legalNameEn': 'Name'},
'payments': {
'type': 'array',
'default': [],
'items': {
'title': 'EuroPayment',
'x-baseClass': 'schematics.models.Model',
'type': 'object',
'description': 'EuroPayment model',
'properties': {
'value': {
'title': 'TestValue',
'x-baseClass': 'schematics.models.Model',
'type': 'object',
'description': 'TestValue model',
'properties': {
'currency': {
'type': 'string',
'enum': ['EUR'],
'x-legalNameUa': 'Евро',
'x-legalNameEn': 'Euro'
},
'amount': {
'type': 'number',
'format': 'float',
'x-format': 'decimal-float',
'x-legalNameUa': 'Сума',
'x-legalNameEn': 'Amount'
}
},
'required': ['currency']
},
'currency': {
'type': 'string',
'enum': ['EUR', 'USD']
}
},
'required': ['value']
}
},
'other_payments': {
'type': 'array',
'default': [],
'items': {
'title': 'EuroPayment',
'x-baseClass': 'schematics.models.Model',
'type': 'object',
'description': 'EuroPayment model',
'properties': {
'value': {
'title': 'TestValue',
'x-baseClass': 'schematics.models.Model',
'type': 'object',
'description': 'TestValue model',
'properties': {
'currency': {
'type': 'string',
'enum': ['EUR'],
'x-legalNameUa': 'Евро',
'x-legalNameEn': 'Euro'
},
'amount': {
'type': 'number',
'format': 'float',
'x-format': 'decimal-float',
'x-legalNameUa': 'Сума',
'x-legalNameEn': 'Amount'
}
},
'required': ['currency']
},
'currency': {
'type': 'string',
'enum': ['UAN', 'eurocent']
}
},
'required': ['value']
}
}
},
'required': ['name', 'payments']
}
assert schema._schemas['models.Bank'] == expanded_model
assert isinstance(cls.fields['payments'], types.ListType)
assert cls.fields['payments'].field.fields['currency'].choices == ['EUR', 'USD']
assert cls.fields['payments'].required is True
assert cls.fields['name'].required is True
assert isinstance(cls.fields['other_payments'], types.ListType)
assert cls.fields['other_payments'].required is False
assert cls.fields['other_payments'].field.fields['currency'].choices == ['UAN', 'eurocent']
def test_overriding_parent_field_stay_required():
test_class_schema = _get_test_cases_models()['_schema']
schema = BaseSchemaBuilder(test_class_schema, 'models')
schema.generate_schema_from_swagger()
cls = getattr(schema.module, 'CustomValueFieldOverriding')
with pytest.raises(DataError, match='{"currency": "This field is required", "amount": "This field is required"}'):
cls()
def test_required_fields_overriding():
test_class_schema = _get_test_cases_models()['_schema']
schema = BaseSchemaBuilder(test_class_schema, 'models')
cls = getattr(schema.module, 'CustomValueRequirementsOverriding')
with pytest.raises(DataError, match='{"currency": "This field is required"}'):
cls()
def test_required_param():
test_class_schema = _get_required_test_cases_models()['_schema']
schema = BaseSchemaBuilder(test_class_schema, 'models')
schema.generate_schema_from_swagger()
a = getattr(schema.module, 'A')
assert a.fields['field_1'].required is True
assert a.fields['field_2'].required is False
b = getattr(schema.module, 'B')
assert b.fields['field_1'].required is True
assert b.fields['field_2'].required is True
c = getattr(schema.module, 'C')
assert c.fields['field_1'].required is False
assert c.fields['field_2'].required is True