Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
bowenliang123 committed Apr 30, 2024
1 parent ecb51a8 commit 396e40c
Show file tree
Hide file tree
Showing 11 changed files with 172 additions and 54 deletions.
6 changes: 2 additions & 4 deletions api/core/model_runtime/model_providers/__base/ai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from abc import ABC, abstractmethod
from typing import Optional

import yaml

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import (
Expand All @@ -18,6 +16,7 @@
)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.tools.utils.yaml_utils import load_yaml_file
from core.utils.position_helper import get_position_map, sort_by_position_map


Expand Down Expand Up @@ -154,8 +153,7 @@ def predefined_models(self) -> list[AIModelEntity]:
# traverse all model_schema_yaml_paths
for model_schema_yaml_path in model_schema_yaml_paths:
# read yaml data from yaml file
with open(model_schema_yaml_path, encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True)

new_parameter_rules = []
for parameter_rule in yaml_data.get('parameter_rules', []):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import os
from abc import ABC, abstractmethod

import yaml

from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.tools.utils.yaml_utils import load_yaml_file
from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source


Expand Down Expand Up @@ -44,10 +43,7 @@ def get_provider_schema(self) -> ProviderEntity:

# read provider schema from yaml file
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
yaml_data = {}
if os.path.exists(yaml_path):
with open(yaml_path, encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
yaml_data = load_yaml_file(yaml_path, ignore_error=True)

try:
# yaml_data to entity
Expand Down
34 changes: 16 additions & 18 deletions api/core/tools/provider/builtin_tool_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from os import listdir, path
from typing import Any

from yaml import FullLoader, load

from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
from core.tools.entities.user_entities import UserToolProviderCredentials
from core.tools.errors import (
Expand All @@ -15,6 +13,7 @@
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.utils.yaml_utils import load_yaml_file
from core.utils.module_import_helper import load_single_subclass_from_source


Expand All @@ -28,10 +27,9 @@ def __init__(self, **data: Any) -> None:
provider = self.__class__.__module__.split('.')[-1]
yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
try:
with open(yaml_path, 'rb') as f:
provider_yaml = load(f.read(), FullLoader)
except:
raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}')
provider_yaml = load_yaml_file(yaml_path)
except Exception as e:
raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}')

if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None:
# set credentials name
Expand All @@ -58,18 +56,18 @@ def _get_builtin_tools(self) -> list[Tool]:
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
tools = []
for tool_file in tool_files:
with open(path.join(tool_path, tool_file), encoding='utf-8') as f:
# get tool name
tool_name = tool_file.split(".")[0]
tool = load(f.read(), FullLoader)
# get tool class, import the module
assistant_tool_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'builtin', provider, 'tools', f'{tool_name}.py'),
parent_type=BuiltinTool)
tool["identity"]["provider"] = provider
tools.append(assistant_tool_class(**tool))
# get tool name
tool_name = tool_file.split(".")[0]
tool = load_yaml_file(path.join(tool_path, tool_file))

# get tool class, import the module
assistant_tool_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'builtin', provider, 'tools', f'{tool_name}.py'),
parent_type=BuiltinTool)
tool["identity"]["provider"] = provider
tools.append(assistant_tool_class(**tool))

self.tools = tools
return tools
Expand Down
18 changes: 9 additions & 9 deletions api/core/tools/utils/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Any, Union

from pydantic import BaseModel
from yaml import FullLoader, load

from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
Expand All @@ -16,6 +15,7 @@
)
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.tool import Tool
from core.tools.utils.yaml_utils import load_yaml_file


class ToolConfigurationManager(BaseModel):
Expand Down Expand Up @@ -254,14 +254,14 @@ def _init_configuration(cls):

for file in files:
provider = file.split('.')[0]
with open(os.path.join(model_tools_path, file), encoding='utf-8') as f:
configurations = ModelToolProviderConfiguration(**load(f, Loader=FullLoader))
models = configurations.models or []
for model in models:
model_key = f'{provider}.{model.model}'
cls._model_configurations[model_key] = model

cls._configurations[provider] = configurations
yaml_data = load_yaml_file(os.path.join(model_tools_path, file))
configurations = ModelToolProviderConfiguration(**yaml_data)
models = configurations.models or []
for model in models:
model_key = f'{provider}.{model.model}'
cls._model_configurations[model_key] = model

cls._configurations[provider] = configurations
cls._inited = True

@classmethod
Expand Down
31 changes: 31 additions & 0 deletions api/core/tools/utils/yaml_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import logging
import os

import yaml
from yaml import YAMLError


def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict:
"""
Safe loading a YAML file to a dict
:param file_path: the path of the YAML file
:param ignore_error:
if True, return empty dict if error occurs and the error will be logged in warning level
if False, raise error if error occurs
:return: a dict of the YAML content
"""
try:
if not file_path or not os.path.exists(file_path):
raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found')

with open(file_path, encoding='utf-8') as file:
try:
return yaml.safe_load(file)
except Exception as e:
raise YAMLError(f'Failed to load YAML file {file_path}: {e}')
except Exception as e:
if ignore_error:
logging.warning(f'Failed to load YAML file {file_path}: {e}')
return {}
else:
raise e
27 changes: 10 additions & 17 deletions api/core/utils/position_helper.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import logging
import os
from collections import OrderedDict
from collections.abc import Callable
from typing import Any, AnyStr

import yaml
from core.tools.utils.yaml_utils import load_yaml_file


def get_position_map(
Expand All @@ -17,21 +16,15 @@ def get_position_map(
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
try:
position_file_name = os.path.join(folder_path, file_name)
if not os.path.exists(position_file_name):
return {}

with open(position_file_name, encoding='utf-8') as f:
positions = yaml.safe_load(f)
position_map = {}
for index, name in enumerate(positions):
if name and isinstance(name, str):
position_map[name.strip()] = index
return position_map
except:
logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.')
return {}
position_file_name = os.path.join(folder_path, file_name)
positions = load_yaml_file(position_file_name, ignore_error=True)
position_map = {}
index = 0
for _, name in enumerate(positions):
if name and isinstance(name, str):
position_map[name.strip()] = index
index += 1
return position_map


def sort_by_position_map(
Expand Down
1 change: 1 addition & 0 deletions api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ select = [
"I001", # unsorted-imports
"I002", # missing-required-import
"UP", # pyupgrade rules
"S506", # unsafe-yaml-load
]
ignore = [
"F403", # undefined-local-with-import-star
Expand Down
Empty file.
29 changes: 29 additions & 0 deletions api/tests/unit_tests/utils/position_helper/test_position_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest

from core.utils.position_helper import get_position_map


@pytest.fixture
def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
tmp_path.joinpath("example_positions.yaml").write_text("""
- first
- second
# - commented
- third
- forth
""")
return str(tmp_path)


def test_position_helper(prepare_example_positions_yaml):
position_map = get_position_map(
folder_path=prepare_example_positions_yaml,
file_name='example_positions.yaml')
assert position_map == {
'first': 0,
'second': 1,
'third': 2,
'forth': 3,
}
Empty file.
72 changes: 72 additions & 0 deletions api/tests/unit_tests/utils/yaml/test_yaml_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pytest
from yaml import YAMLError

from core.tools.utils.yaml_utils import load_yaml_file

EXAMPLE_YAML_FILE = 'example_yaml.yaml'
INVALID_YAML_FILE = 'invalid_yaml.yaml'
NON_EXISTING_YAML_FILE = 'non_existing_file.yaml'


@pytest.fixture
def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE)
file_path.write_text("""
address:
city: Example City
country: Example Country
age: 30
gender: male
languages:
- Python
- Java
- C++
""")
return str(file_path)


@pytest.fixture
def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
file_path = tmp_path.joinpath(INVALID_YAML_FILE)
file_path.write_text("""
address:
city: Example City
country: Example Country
age: 30
empty_key:
gender: male
languages:
- Python
- Java
- C++
""")
return str(file_path)


def test_load_yaml_non_existing_file():
assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=True) == {}
assert load_yaml_file(file_path='', ignore_error=True) == {}

with pytest.raises(FileNotFoundError, match=f'Failed to load YAML file {NON_EXISTING_YAML_FILE}: file not found'):
load_yaml_file(file_path=NON_EXISTING_YAML_FILE)


def test_load_valid_yaml_file(prepare_example_yaml_file):
yaml_data = load_yaml_file(file_path=prepare_example_yaml_file)
assert len(yaml_data) > 0
assert yaml_data['age'] == 30
assert yaml_data['gender'] == 'male'
assert yaml_data.get('empty_key') is None
assert yaml_data['address']['city'] == 'Example City'
assert set(yaml_data['languages']) == {'Python', 'Java', 'C++'}


def test_load_invalid_yaml_file(prepare_invalid_yaml_file):
# yaml syntax error
with pytest.raises(YAMLError):
load_yaml_file(file_path=prepare_invalid_yaml_file)

# ignore error
assert load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=True) == {}

0 comments on commit 396e40c

Please sign in to comment.