Merge f1db084b4b into ab743b9358
Commit 25bbc13506 (mirror of https://github.com/labring/FastGPT.git, synced 2025-12-25 20:02:47 +00:00)
@@ -0,0 +1,14 @@
__pycache__
.benchmarks
.idea
.mypy_cache
.pytest_cache
.ropeproject
.ruff_cache
.venv
.vscode
.zed

datasets/bird_minidev
.coverage
.DS_Store
@@ -0,0 +1,34 @@
FROM astral/uv:0.8-python3.11-bookworm-slim AS builder

RUN apt-get update && apt-get install -y --no-install-recommends gcc libc-dev

ENV UV_COMPILE_BYTECODE=1 UV_NO_INSTALLER_METADATA=1 UV_LINK_MODE=copy UV_PYTHON_DOWNLOADS=0

# Install dependencies
RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,source=uv.lock,target=uv.lock \
    --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
    --mount=type=bind,source=scripts/install_duckdb_extensions.py,target=install_duckdb_extensions.py \
    uv sync --frozen --no-install-project --no-dev \
    && uv run python install_duckdb_extensions.py

# clean the cache
RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,source=uv.lock,target=uv.lock \
    --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
    uv sync --frozen --no-install-project --no-dev

FROM python:3.11-slim-bookworm AS runtime

COPY --from=builder --chown=dative:dative /.venv /.venv
COPY --from=builder --chown=dative:dative /root/.duckdb /root/.duckdb

COPY src/dative /dative

COPY docker/entrypoint.sh /

ENV PATH="/.venv/bin:$PATH"

EXPOSE 3000

ENTRYPOINT ["/entrypoint.sh"]
@@ -0,0 +1,44 @@
# Install project dependencies
install:
    uv sync

U:
    uv sync --upgrade

up package:
    uv sync --upgrade-package {{ package }}

# Format the project code
format:
    uv run ruff format src tests/unit

# Run the project lint checks
check:
    uv run ruff check --fix src tests/unit

sqlfmt: install
    uv run scripts/sqlfmt.py src/dative/core/data_source/metadata_sql

# Run type checks
type:
    uv run mypy src

# Run unit tests
test:
    uv run pytest tests/unit

# Run functional tests
test_function:
    uv run pytest tests/function

# Run integration tests
test_integration:
    uv run pytest tests/integration

ci: install format check type test

dev: install
    uv run fastapi dev src/dative/main.py

run: install
    uv run fastapi run src/dative/main.py
@@ -0,0 +1,32 @@
Takes a natural-language question, generates a SQL query, and answers in natural language.

## Features

- Supports MySQL, PostgreSQL, SQLite, and other databases
- Uses DuckDB's local data-management capabilities to query structured data stored locally or on S3, e.g. xlsx, xls, csv, xlsm, xlsb
- Basic SQL syntax checking and optimization
- Queries a single database at a time; cross-database queries are not supported
- SQL queries only; UPDATE, DELETE, INSERT, and other write statements are not supported
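For reference, the request body of the natural-language query endpoint follows the `QueryByNLRequest` model added in this PR. The sketch below is illustrative only: the route prefix, host/port, credentials, model name, and metadata payload are assumptions, not values taken from this repository (the Docker entrypoint listens on port 3000 by default).

```python
# Hypothetical example of calling the natural-language query endpoint with httpx
# (httpx is already a project dependency). Field names follow QueryByNLRequest,
# MysqlConfig, and LLMInfo from data_model.py; the URL prefix and all values are placeholders.
import httpx

payload = {
    "source_config": {
        "type": "mysql",
        "host": "127.0.0.1",
        "port": 3306,
        "username": "root",
        "password": "example",
        "db_name": "sales",
    },
    "query": "How many orders were placed last month?",
    # DatabaseMetadata shape; normally retrieved via the get_metadata endpoint first.
    "retrieved_metadata": {"name": "sales", "tables": []},
    "generate_sql_llm": {"model": "gpt-4o-mini", "api_key": "sk-...", "base_url": "https://api.openai.com/v1"},
}

resp = httpx.post("http://localhost:3000/api/v1/data_source/query_by_nl", json=payload, timeout=120)
print(resp.json())  # {"answer": ..., "sql": ..., "sql_res": {...}, "input_tokens": ..., "output_tokens": ...}
```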
## Local development

1. The project is managed with [uv](https://github.com/astral-sh/uv); install it with pip:
```bash
pip install uv
```

2. Install the dependencies:
```bash
uv sync
```

3. Install the DuckDB extensions:
```bash
uv run python scripts/install_duckdb_extensions.py
```

4. Start the service:
```bash
uv run fastapi run src/dative/main.py
```
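Once the service is running, a quick smoke test is to send a plain SQL query against a local SQLite file via the `sql_query` endpoint (request shape: `SqlQueryRequest`). The snippet below is a sketch under assumptions: the route prefix and the database path are placeholders, and `fastapi run` serves on port 8000 by default (the Docker entrypoint uses 3000 instead).

```python
# Hypothetical smoke test against a locally running dev server.
import httpx

payload = {
    "source_config": {"type": "sqlite", "db_path": "./example.db"},  # placeholder path
    "sql": "SELECT 1",
}
resp = httpx.post("http://localhost:8000/api/v1/data_source/sql_query", json=payload)
print(resp.status_code, resp.json())  # expected shape: {"cols": [...], "data": [...]}
```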
@@ -0,0 +1,51 @@
#!/bin/bash

set -e

# Function to check if environment is development (case-insensitive)
is_development() {
    local env
    env=$(echo "${DATIVE_ENVIRONMENT}" | tr '[:upper:]' '[:lower:]')
    [[ "$env" == "development" ]]
}

start_arq() {
    echo "Starting ARQ worker..."
    # Use exec to replace the shell with arq_worker if running in single process mode
    if ! is_development; then
        exec arq_worker
    else
        arq_worker &
        ARQ_PID=$!
        echo "ARQ worker started with PID: $ARQ_PID"
    fi
}

start_uvicorn() {
    echo "Starting Uvicorn server..."
    # Use PORT environment variable, default to 3000 if not set
    local port="${PORT:-3000}"
    # Use exec to replace the shell with uvicorn if running in single process mode
    if ! is_development; then
        exec uvicorn dative.main:app --host '0.0.0.0' --port "$port"
    else
        uvicorn dative.main:app --host '0.0.0.0' --port "$port" &
        UVICORN_PID=$!
        echo "Uvicorn server started with PID: $UVICORN_PID"
    fi
}

# Start the appropriate service
if [[ "${MODE}" == "worker" ]]; then
    start_arq
else
    start_uvicorn
fi

if is_development; then
    sleep infinity
else
    echo "Signal handling disabled. Process will replace shell (exec)."
    # In this case, start_arq or start_uvicorn would have used exec
    # So we shouldn't reach here unless something went wrong
fi
@@ -0,0 +1,110 @@
[project]
name = "dative"
version = "0.1.0"
description = "Query data from database by natural language"
requires-python = "~=3.11.0"
dependencies = [
    "aiobotocore>=2.24.2",
    "aiosqlite>=0.21.0",
    "asyncpg>=0.30.0",
    "cryptography>=45.0.6",
    "duckdb>=1.3.2",
    "fastapi<1.0",
    "fastexcel>=0.15.1",
    "greenlet>=3.2.4",
    "httpx[socks]>=0.28.1",
    "json-repair>=0.50.1",
    "langchain-core<1.0",
    "langchain-openai<1.0",
    "mysql-connector-python>=9.4.0",
    "orjson>=3.11.2",
    "pyarrow>=21.0.0",
    "pydantic-settings>=2.10.1",
    "python-dateutil>=2.9.0.post0",
    "python-dotenv>=1.1.1",
    "pytz>=2025.2",
    "sqlglot[rs]==27.12.0",
    "uvicorn[standard]<1.0",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.sdist]
only-include = ["src/"]

[tool.hatch.build.targets.wheel]
packages = ["src/"]

[tool.uv]
index-url = "https://pypi.tuna.tsinghua.edu.cn/simple"
extra-index-url = ["https://mirrors.aliyun.com/pypi/simple"]

[dependency-groups]
dev = [
    "hatch>=1.14.1",
    "ruff>=0.12.10",
    "pytest-async>=0.1.1",
    "pytest-asyncio>=0.26.0",
    "pytest-cov>=6.3.0",
    "pytest-mock>=3.14.0",
    "pytest>=8.3.3",
    "pytest-benchmark>=4.0.0",
    "fastapi-cli>=0.0.8",
    "mypy>=1.17.1",
    "types-python-dateutil>=2.9.0.20250822",
    "pyarrow-stubs>=20.0.0.20250825",
    "asyncpg-stubs>=0.30.2",
    "boto3-stubs>=1.40.25",
    "types-aiobotocore[essential]>=2.24.2",
]

[tool.uv.sources]

# Lint checks
[tool.ruff]
line-length = 120
target-version = "py311"
preview = true

[tool.ruff.lint]
select = [
    "E",    # pycodestyle
    "F",    # Pyflakes
    "W",    # Warning
    "N",    # PEP8 Naming
    "I",    # isort
    "FAST", # FastAPI
]
# Disabled rules
ignore = []

[tool.sqruff]
output_line_length = 120
max_line_length = 120

# Type checking
[tool.mypy]
mypy_path = "./src"
exclude = ["tests/"]

# Unit tests
[tool.pytest.ini_options]
addopts = ["-rA", "--cov=dative"]
testpaths = ["tests/unit"]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
python_files = ["*.py"]

# Code coverage
[tool.coverage.run]
branch = true
parallel = true
omit = ["**/__init__.py"]

[tool.coverage.report]
# fail_under = 85
show_missing = true
sort = "cover"
exclude_lines = ["no cov", "if __name__ == .__main__.:"]
@@ -0,0 +1,8 @@
import duckdb

extensions = [
    "excel",
    "httpfs"
]
for ext in extensions:
    duckdb.install_extension(ext)
@@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
import argparse
import sys
from pathlib import Path

from sqlglot import transpile
from sqlglot.dialects.dialect import DIALECT_MODULE_NAMES
from sqlglot.errors import ParseError


def transpile_sql_file(file: Path, dialect: str) -> None:
    if dialect not in DIALECT_MODULE_NAMES:
        raise ValueError(f"Dialect {dialect} not supported")

    if not file.is_file() or not file.suffix == ".sql":
        print("Please specify a sql file")
        return

    sql_txt = file.read_text(encoding="utf-8")
    try:
        sqls = transpile(sql_txt, read=dialect, write=dialect, pretty=True, indent=4, pad=4)
        file.write_text(";\n\n".join(sqls) + "\n", encoding="utf-8")
    except ParseError as e:
        print(str(e))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Format SQL files with sqlglot")
    parser.add_argument("files_or_dirs", nargs="+", type=str, default=None, help="files or directories")
    parser.add_argument("-d", "--dialect", type=str, help="SQL dialect used for parsing and writing")
    args = parser.parse_args()

    if args.files_or_dirs:
        for file_or_dir in args.files_or_dirs:
            fd = Path(file_or_dir)
            if fd.is_file():
                dialect = args.dialect or fd.stem.split(".")[-1]
                transpile_sql_file(fd, dialect)
            elif fd.is_dir():
                for sql_file in fd.glob("*.sql"):
                    dialect = args.dialect or sql_file.stem.split(".")[-1]
                    transpile_sql_file(sql_file, dialect)
            else:
                print(f"{file_or_dir} is not a sql file or a directory")
                sys.exit(3)
    else:
        print("Please specify a file or a directory")
        sys.exit(3)
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-

from fastapi import APIRouter

from . import data_source

router = APIRouter()

router.include_router(data_source.router, prefix="/data_source", tags=["data_source"])
@@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
import os
from typing import Any, Literal, Union

from pydantic import BaseModel, Field, SecretStr

from dative.core.data_source import DatabaseMetadata


class DatabaseConfigBase(BaseModel):
    host: str = Field(description="host")
    port: int = Field(default=3306, description="port")
    username: str = Field(description="username")
    password: str = Field(description="password")


class SqliteConfig(BaseModel):
    type: Literal["sqlite"] = "sqlite"
    db_path: str = Field(description="db_path")


class MysqlConfig(DatabaseConfigBase):
    type: Literal["mysql", "maria"] = "mysql"
    db_name: str = Field(description="database name")
    conn_pool_size: int = Field(default=3, description="database connection pool size")


class PostgresConfig(DatabaseConfigBase):
    type: Literal["postgres"] = "postgres"
    db_name: str = Field(description="database name")
    ns_name: str = Field(default="public", alias="schema", description="schema name")
    conn_pool_size: int = Field(default=3, description="database connection pool size")


class DuckdbLocalStoreConfig(BaseModel):
    type: Literal["local"] = Field(default="local", description="storage backend type")
    dir_path: str = Field(description="directory containing the Excel/data files")


class DuckdbS3StoreConfig(BaseModel):
    type: Literal["s3"] = Field(default="s3", description="storage backend type")
    host: str = Field(description="host")
    port: int = Field(default=3306, description="port")
    access_key: str = Field(description="access_key")
    secret_key: str = Field(description="secret_key")
    bucket: str = Field(description="bucket")
    region: str = Field(default="", description="region")
    use_ssl: bool = Field(default=False, description="use_ssl")


class DuckdbConfig(BaseModel):
    type: Literal["duckdb"] = "duckdb"
    store: DuckdbLocalStoreConfig | DuckdbS3StoreConfig = Field(description="storage backend", discriminator="type")


DataSourceConfig = Union[SqliteConfig, MysqlConfig, PostgresConfig, DuckdbConfig]


def default_api_key() -> SecretStr:
    if os.getenv("AIPROXY_API_TOKEN"):
        api_key = SecretStr(str(os.getenv("AIPROXY_API_TOKEN")))
    else:
        api_key = SecretStr("")
    return api_key


def default_base_url() -> str:
    if os.getenv("AIPROXY_API_ENDPOINT"):
        base_url = str(os.getenv("AIPROXY_API_ENDPOINT")) + "/v1"
    else:
        base_url = "https://api.openai.com/v1"
    return base_url


class LLMInfo(BaseModel):
    provider: Literal["openai"] = Field(default="openai", description="LLM provider")
    model: str = Field(description="LLM model name")
    api_key: SecretStr = Field(default_factory=default_api_key, description="API key", examples=["sk-..."])
    base_url: str = Field(
        default_factory=default_base_url, description="API base URL", examples=["https://api.openai.com/v1"]
    )
    temperature: float = Field(default=0.7, description="sampling temperature")
    max_tokens: int | None = Field(default=None, description="maximum generation length")
    extra_body: dict[str, Any] | None = Field(default=None)


class SqlQueryRequest(BaseModel):
    source_config: DataSourceConfig = Field(description="database connection info", discriminator="type")
    sql: str


class SqlQueryResponse(BaseModel):
    cols: list[str] = Field(default_factory=list, description="result columns")
    data: list[tuple[Any, ...]] = Field(default_factory=list, description="result rows")


class QueryByNLRequest(BaseModel):
    source_config: DataSourceConfig = Field(description="database connection info", discriminator="type")
    query: str = Field(description="user question")
    retrieved_metadata: DatabaseMetadata = Field(description="retrieved metadata")
    generate_sql_llm: LLMInfo = Field(description="LLM configuration used to generate SQL")
    evaluate_sql_llm: LLMInfo | None = Field(default=None, description="LLM configuration used to evaluate SQL results")
    schema_format: Literal["markdown", "m_schema"] = Field(default="markdown", description="format used to render the schema into the prompt")
    evidence: str = Field(default="", description="supplementary information")
    result_num_limit: int = Field(default=100, description="maximum number of result rows")


class QueryByNLResponse(BaseModel):
    answer: str = Field(description="generated answer")
    sql: str = Field(description="generated SQL statement")
    sql_res: SqlQueryResponse = Field(description="SQL query result")
    input_tokens: int = Field(description="number of input tokens")
    output_tokens: int = Field(description="number of output tokens")
@@ -0,0 +1,212 @@
# -*- coding: utf-8 -*-

import traceback
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import ORJSONResponse
from langchain_openai import ChatOpenAI

from dative.api.v1.data_model import (
    DataSourceConfig,
    QueryByNLRequest,
    QueryByNLResponse,
    SqlQueryRequest,
    SqlQueryResponse,
)
from dative.core.agent import Agent
from dative.core.data_source import (
    DatabaseMetadata,
    DataSourceBase,
    DBServerVersion,
    DuckdbLocalStore,
    DuckdbS3Store,
    Mysql,
    Postgres,
    Sqlite,
)

router = APIRouter()


async def valid_data_source_config(source_config: DataSourceConfig) -> DataSourceBase:
    ds: DataSourceBase
    if source_config.type == "mysql" or source_config.type == "maria":
        ds = Mysql(
            host=source_config.host,
            port=source_config.port,
            username=source_config.username,
            password=source_config.password,
            db_name=source_config.db_name,
        )
    elif source_config.type == "postgres":
        ds = Postgres(
            host=source_config.host,
            port=source_config.port,
            username=source_config.username,
            password=source_config.password,
            db_name=source_config.db_name,
            schema=source_config.ns_name,
        )
    elif source_config.type == "duckdb":
        if source_config.store.type == "local":
            ds = DuckdbLocalStore(dir_path=source_config.store.dir_path)
        elif source_config.store.type == "s3":
            ds = DuckdbS3Store(
                host=source_config.store.host,
                port=source_config.store.port,
                access_key=source_config.store.access_key,
                secret_key=source_config.store.secret_key,
                bucket=source_config.store.bucket,
                region=source_config.store.region,
                use_ssl=source_config.store.use_ssl,
            )
        else:
            raise HTTPException(
                status_code=400,
                detail={
                    "msg": f"Unsupported duckdb storage type: {source_config.store}",
                    "error": f"Unsupported duckdb storage type: {source_config.store}",
                },
            )
    elif source_config.type == "sqlite":
        ds = Sqlite(source_config.db_path)
    else:
        raise HTTPException(
            status_code=400,
            detail={
                "msg": f"Unsupported data source types: {source_config.type}",
                "error": f"Unsupported data source types: {source_config.type}",
            },
        )
    try:
        await ds.conn_test()
        return ds
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail={
                "msg": "Connection failed. Please check the connection information.",
                "error": str(e),
            },
        )


@router.post("/conn_test", response_class=ORJSONResponse, dependencies=[Depends(valid_data_source_config)])
async def conn_test() -> str:
    return "ok"


@router.post("/get_metadata", response_class=ORJSONResponse)
async def get_metadata(ds: Annotated[DataSourceBase, Depends(valid_data_source_config)]) -> DatabaseMetadata:
    try:
        return await ds.aget_metadata()
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail={
                "msg": "Failed to obtain database metadata",
                "error": str(e),
            },
        )


@router.post("/get_metadata_with_value_examples", response_class=ORJSONResponse)
async def get_metadata_with_value_example(
    ds: Annotated[DataSourceBase, Depends(valid_data_source_config)],
) -> DatabaseMetadata:
    try:
        return await ds.aget_metadata_with_value_examples()
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail={
                "msg": "Failed to obtain database metadata",
                "error": str(e),
            },
        )


@router.post("/get_server_version", response_class=ORJSONResponse)
async def get_server_version(ds: Annotated[DataSourceBase, Depends(valid_data_source_config)]) -> DBServerVersion:
    try:
        return await ds.aget_server_version()
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail={
                "msg": "Failed to obtain database version information",
                "error": str(e),
            },
        )


@router.post("/sql_query", response_class=ORJSONResponse)
async def sql_query(request: SqlQueryRequest) -> SqlQueryResponse:
    ds = await valid_data_source_config(source_config=request.source_config)
    cols, data, err = await ds.aexecute_raw_sql(sql=request.sql)
    if err:
        raise HTTPException(
            status_code=400,
            detail={
                "msg": "Database query failed",
                "error": err.msg,
            },
        )
    return SqlQueryResponse(cols=cols, data=data)


@router.post("/query_by_nl", response_class=ORJSONResponse)
async def query_by_nl(request: QueryByNLRequest) -> QueryByNLResponse:
    ds = await valid_data_source_config(source_config=request.source_config)
    try:
        server_version = await ds.aget_server_version()
    except Exception as e:
        print(traceback.format_exc())
        raise HTTPException(
            status_code=400,
            detail={
                "msg": "Connection failed. Please check the connection information.",
                "error": str(e),
            },
        )
    generate_sql_llm = ChatOpenAI(
        model=request.generate_sql_llm.model,
        api_key=request.generate_sql_llm.api_key,
        base_url=request.generate_sql_llm.base_url,
        stream_usage=True,
        extra_body=request.generate_sql_llm.extra_body,
    )
    if request.evaluate_sql_llm:
        evaluate_sql_llm = ChatOpenAI(
            model=request.evaluate_sql_llm.model,
            api_key=request.evaluate_sql_llm.api_key,
            base_url=request.evaluate_sql_llm.base_url,
            stream_usage=True,
            extra_body=request.evaluate_sql_llm.extra_body,
        )
    else:
        evaluate_sql_llm = None
    agent = Agent(
        ds=ds,
        db_server_version=server_version,
        generate_sql_llm=generate_sql_llm,
        result_num_limit=request.result_num_limit,
        evaluate_sql_llm=evaluate_sql_llm,
    )
    try:
        answer, sql, cols, data, total_input_tokens, total_output_tokens = await agent.arun(
            query=request.query, metadata=request.retrieved_metadata, evidence=request.evidence
        )
    except Exception as e:
        print(traceback.format_exc())
        raise HTTPException(
            status_code=400,
            detail={
                "msg": "Database search failed",
                "error": str(e),
            },
        )
    return QueryByNLResponse(
        answer=answer,
        sql=sql,
        sql_res=SqlQueryResponse(cols=cols, data=data),
        input_tokens=total_input_tokens,
        output_tokens=total_output_tokens,
    )
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
from typing import ClassVar

from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
        env_file=".env",
        env_prefix="DATIVE_",
        env_file_encoding="utf-8",
        extra="ignore",
        env_nested_delimiter="__",
        nested_model_default_partial_update=True,
    )

    DATABASE_URL: str = ""
    log_level: str = "INFO"
    log_format: str = "plain"
    redis_url: str = ""


settings = Settings()
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
@@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
from typing import Any, Literal, Optional

from langchain_core.language_models import BaseChatModel
from sqlglot.errors import SqlglotError

from . import sql_generation as sql_gen
from . import sql_res_evaluation as sql_eval
from .data_source import DatabaseMetadata, DataSourceBase, DBServerVersion, SQLError
from .sql_inspection import SQLCheck, SQLOptimization


class Agent:
    def __init__(
        self,
        ds: DataSourceBase,
        db_server_version: DBServerVersion,
        generate_sql_llm: BaseChatModel,
        result_num_limit: int = 100,
        schema_format: Literal["markdown", "m_schema"] = "m_schema",
        max_attempts: int = 2,
        evaluate_sql_llm: Optional[BaseChatModel] = None,
    ):
        self.ds = ds
        self.generate_sql_llm = generate_sql_llm
        self.schema_format = schema_format
        self.db_server_version = db_server_version
        self.sql_check = SQLCheck(ds.dialect)
        self.sql_opt = SQLOptimization(dialect=ds.dialect, db_major_version=db_server_version.major)
        self.result_num_limit = result_num_limit
        self.max_attempts = max_attempts
        self.evaluate_sql_llm = evaluate_sql_llm

    async def arun(
        self, query: str, metadata: DatabaseMetadata, evidence: str = ""
    ) -> tuple[str, str, list[str], list[tuple[Any, ...]], int, int]:
        answer = None
        sql = ""
        cols: list[str] = []
        data: list[tuple[Any, ...]] = []
        total_input_tokens = 0
        total_output_tokens = 0
        attempts = 0

        schema_str = f"# Database server info: {self.ds.dialect} {self.db_server_version}\n"
        if self.schema_format == "markdown":
            schema_str += metadata.to_markdown()
        else:
            schema_str += metadata.to_m_schema()
        schema_type = {table.name: table.column_type() for table in metadata.tables if table.enabled}
        current_sql, error_msg = None, None
        while attempts < self.max_attempts:
            answer, sql, input_tokens, output_tokens, error_msg = await sql_gen.arun(
                query=query,
                llm=self.generate_sql_llm,
                db_info=schema_str,
                evidence=evidence,
                error_sql=current_sql,
                error_msg=error_msg,
            )
            total_input_tokens += input_tokens
            total_output_tokens += output_tokens
            if answer is not None:
                return answer, sql, cols, data, total_input_tokens, total_output_tokens
            elif error_msg is None and sql:
                try:
                    sql_exp = self.sql_check.syntax_valid(sql)
                    if not self.sql_check.is_query(sql_exp):
                        return answer or "", sql, cols, data, total_input_tokens, total_output_tokens

                    sql = self.sql_opt.arun(sql_exp, schema_type=schema_type, result_num_limit=self.result_num_limit)
                except SqlglotError as e:
                    error_msg = f"SQL syntax error: {e}"

            if not error_msg:
                cols, data, err = await self.ds.aexecute_raw_sql(sql)
                if err:
                    if err.error_type != SQLError.SyntaxError:
                        return answer or "", sql, cols, data, total_input_tokens, total_output_tokens
                    else:
                        error_msg = err.msg

            if not error_msg and self.evaluate_sql_llm:
                answer, input_tokens, output_tokens, error_msg = await sql_eval.arun(
                    query=query,
                    llm=self.evaluate_sql_llm,
                    db_info=schema_str,
                    sql=sql,
                    res_rows=data,
                    evidence=evidence,
                )
                total_input_tokens += input_tokens
                total_output_tokens += output_tokens

            if not error_msg:
                break
            attempts += 1
            current_sql = sql

        return answer or "", sql, cols, data, total_input_tokens, total_output_tokens
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-

from .base import (
    DatabaseMetadata,
    DataSourceBase,
    DBException,
    DBServerVersion,
    DBTable,
    SQLError,
    SQLException,
    TableColumn,
    TableForeignKey,
)
from .duckdb_ds import DuckdbLocalStore, DuckdbS3Store
from .mysql_ds import Mysql
from .postgres_ds import Postgres
from .sqlite_ds import Sqlite

__all__ = [
    "DataSourceBase",
    "DatabaseMetadata",
    "DBServerVersion",
    "TableColumn",
    "TableForeignKey",
    "DBTable",
    "SQLError",
    "SQLException",
    "Mysql",
    "Postgres",
    "DBException",
    "DuckdbLocalStore",
    "DuckdbS3Store",
    "Sqlite",
]
@@ -0,0 +1,288 @@
# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from enum import StrEnum, auto
from typing import Any, cast

import orjson
from pydantic import BaseModel, Field
from sqlglot import exp

from dative.core.utils import convert_value2str, is_date, is_email, is_number, is_valid_uuid, truncate_text


class TableColumn(BaseModel):
    name: str = Field(description="column name")
    type: str = Field(description="data type")
    comment: str = Field(default="", description="description")
    auto_increment: bool = Field(default=False, description="whether the column is auto-increment")
    nullable: bool = Field(default=True, description="whether NULL is allowed")
    default: Any = Field(default=None, description="default value")
    examples: list[Any] = Field(default_factory=list, description="example values")
    enabled: bool = Field(default=True, description="whether the column is enabled")
    value_index: bool = Field(default=False, description="whether value indexing is enabled")


class ConstraintKey(BaseModel):
    name: str = Field(description="constraint name")
    column: str = Field(description="constrained column name")


class TableForeignKey(ConstraintKey):
    referenced_schema: str = Field(description="referenced schema")
    referenced_table: str = Field(description="referenced table name")
    referenced_column: str = Field(description="referenced column name")


class DBTable(BaseModel):
    name: str = Field(description="table name")
    ns_name: str | None = Field(default=None, alias="schema", description="schema name")
    comment: str = Field(default="", description="description")
    columns: dict[str, TableColumn] = Field(default_factory=dict, description="columns")
    primary_keys: list[str] = Field(default_factory=list, description="primary keys")
    foreign_keys: list[TableForeignKey] = Field(default_factory=list, description="foreign keys")
    enabled: bool = Field(default=True, description="whether the table is enabled")

    def column_type(self, case_insensitive: bool = True) -> dict[str, str]:
        schema_type: dict[str, str] = dict()
        for col in self.columns.values():
            schema_type[col.name] = col.type
            if case_insensitive:
                schema_type[col.name.title()] = col.type
                schema_type[col.name.lower()] = col.type
                schema_type[col.name.upper()] = col.type
        return schema_type

    def to_markdown(self) -> str:
        s = f"""
### Table name: {self.name}
{self.comment}
#### Columns:
|Name|Description|Type|Examples|
|---|---|---|---|
"""
        for col in self.columns.values():
            if not col.enabled:
                continue

            example_values = "<br>".join([convert_value2str(v) for v in col.examples])
            s += f"|{col.name}|{col.comment}|{col.type}|{example_values}|\n"

        if self.primary_keys:
            s += f"#### Primary Keys: {tuple(self.primary_keys)}\n"
        if self.foreign_keys:
            fk_str = "#### Foreign Keys: \n"
            for fk in self.foreign_keys:
                # Only show foreign keys that reference the same schema
                if self.ns_name is None or self.ns_name == fk.referenced_schema:
                    fk_str += f" - {fk.column} -> {fk.referenced_table}.{fk.referenced_column}\n"
            s += fk_str

        return s

    def to_m_schema(self, db_name: str | None = None) -> str:
        # XiYanSQL: https://github.com/XGenerationLab/M-Schema
        output = []
        if db_name:
            table_comment = f"# Table: {db_name}.{self.name}"
        else:
            table_comment = f"# Table: {self.name}"
        if self.comment:
            table_comment += f", {self.comment}"
        output.append(table_comment)

        field_lines = []
        for col in self.columns.values():
            if not col.enabled:
                continue

            field_line = f"({col.name}: {col.type.upper()}"
            if col.name in self.primary_keys:
                field_line += ", Primary Key"
            if col.comment:
                field_line += f", {col.comment.strip()}"
            if col.examples:
                example_values = ", ".join([convert_value2str(v) for v in col.examples])
                field_line += f", Examples: [{example_values}]"
            field_line += ")"
            field_lines.append(field_line)

        output.append("[")
        output.append(",\n".join(field_lines))
        output.append("]")
        return "\n".join(output)


class SQLError(StrEnum):
    EmptySQL = auto()
    SyntaxError = auto()
    NotAllowedOperation = auto()
    DBError = auto()
    UnknownError = auto()


class SQLException(BaseModel):
    error_type: SQLError
    msg: str
    code: int = Field(default=0, description="error code")


class DatabaseMetadata(BaseModel):
    name: str = Field(description="database name")
    comment: str = Field(default="", description="description")
    tables: list[DBTable] = Field(default_factory=list, description="tables")

    def to_markdown(self) -> str:
        s = f"""
# Database name: {self.name}
{self.comment}
## Tables:
"""
        for table in self.tables:
            if table.enabled:
                s += f"{table.to_markdown()}\n"
        return s

    def to_m_schema(self) -> str:
        output = [f"【DB_ID】 {self.name}"]
        if self.comment:
            output.append(self.comment)
        output.append("【Schema】")
        output.append("schema format: (column name: data type, is primary key, comment, examples)")
        fks = []
        for t in self.tables:
            if not t.enabled:
                continue
            output.append(t.to_m_schema(self.name))

            for fk in t.foreign_keys:
                # Only show foreign keys that reference the same schema
                if t.ns_name is None or t.ns_name == fk.referenced_schema:
                    fks.append(f"{t.name}.{fk.column}={fk.referenced_table}.{fk.referenced_column}")

        if fks:
            output.append("【Foreign Keys】")
            output.extend(fks)
        output.append("\n")
        return "\n".join(output)


class DBServerVersion(BaseModel):
    major: int = Field(description="major version")
    minor: int = Field(description="minor version")
    patch: int | None = Field(default=None, description="patch version")

    def __str__(self) -> str:
        s = f"{self.major}.{self.minor}"
        if self.patch is None:
            return s
        return s + f".{self.patch}"

    def __repr__(self) -> str:
        return str(self)


class DBException(BaseModel):
    code: int = Field(description="exception code")
    msg: str = Field(description="exception message")


class DataSourceBase(ABC):
    @property
    @abstractmethod
    def dialect(self) -> str:
        """return the type of data source"""
        raise NotImplementedError

    @property
    @abstractmethod
    def string_types(self) -> set[str]:
        """return the string types of data source"""
        raise NotImplementedError

    @property
    @abstractmethod
    def json_array_agg_func(self) -> str:
        """
        Get the SQL name of the JSON array aggregation function.

        Returns:
            str: the JSON array aggregation function name appropriate for the current database type
        """
        raise NotImplementedError

    async def conn_test(self) -> bool:
        await self.aexecute_raw_sql("SELECT 1")
        return True

    @abstractmethod
    async def aget_server_version(self) -> DBServerVersion:
        """get server version"""

    @abstractmethod
    async def aget_metadata(self) -> DatabaseMetadata:
        """get database metadata"""

    @abstractmethod
    async def aexecute_raw_sql(self, sql: str) -> tuple[list[str], list[tuple[Any, ...]], SQLException | None]:
        """Execute a raw SQL statement and return (columns, rows, error)."""

    @staticmethod
    async def distinct_values_exp(table_name: str, col_name: str, limit: int = 3) -> exp.Expression:
        sql_exp = (
            exp.select(exp.column(col_name))
            .distinct()
            .from_(exp.to_identifier(table_name))
            .where(exp.not_(exp.column(col_name).is_(exp.null())))
            .limit(limit)
        )
        return sql_exp

    async def aget_metadata_with_value_examples(
        self, value_num: int = 3, value_str_max_length: int = 40
    ) -> DatabaseMetadata:
        """
        Get database metadata enriched with example values.

        Args:
            value_num (int): number of example values to collect per column, default 3
            value_str_max_length (int): maximum length of string example values, default 40

        Returns:
            DatabaseMetadata: metadata object containing the database schema and example values
        """
        db_metadata = await self.aget_metadata()

        for t in db_metadata.tables:
            col_exp = []
            for col in t.columns.values():
                col_exp.append(await self.agg_distinct_values(t.name, col.name, value_num))

            sql = exp.union(*col_exp, distinct=False).sql(dialect=self.dialect)
            _, rows, _ = await self.aexecute_raw_sql(sql)
            rows = cast(list[tuple[str, str]], rows)
            for i in range(len(rows)):
                col_name = rows[i][0]
                if rows[i][1]:
                    examples = orjson.loads(cast(str, rows[i][1]))
                    str_examples = [
                        truncate_text(convert_value2str(v), max_length=value_str_max_length) for v in examples
                    ]
                    t.columns[col_name].examples = str_examples
                    if (
                        t.columns[col_name].type in self.string_types
                        and not is_valid_uuid(str_examples[0])
                        and not is_number(str_examples[0])
                        and not is_date(str_examples[0])
                        and not is_email(str_examples[0])
                    ):
                        t.columns[col_name].value_index = True

        return db_metadata

    async def agg_distinct_values(self, table_name: str, col_name: str, limit: int = 3) -> exp.Expression:
        dis_exp = await self.distinct_values_exp(table_name, col_name, limit)
        sql_exp = exp.select(
            exp.Alias(this=exp.Literal.string(col_name), alias="name"),
            exp.Alias(this=exp.func(self.json_array_agg_func, exp.column(col_name)), alias="examples"),
        ).from_(exp.Subquery(this=dis_exp, alias="t"))
        return sql_exp
@@ -0,0 +1,212 @@
# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, cast

import duckdb
import orjson
import pyarrow as pa
from aiobotocore.session import get_session
from botocore.exceptions import ClientError
from fastexcel import read_excel

from .base import DatabaseMetadata, DataSourceBase, DBServerVersion, SQLError, SQLException

with open(Path(__file__).parent / "metadata_sql" / "duckdb.sql", encoding="utf-8") as f:
    METADATA_SQL = f.read()

supported_file_extensions = {"xlsx", "xls", "xlsm", "xlsb", "csv", "json", "parquet"}


class DuckdbBase(DataSourceBase, ABC):
    """
    DuckDB vs. polars: polars does not yet support SQL query optimization over Excel (the lazy pl.scan_* API).
    For large files with many rows, pushing the query down to SQL avoids loading the whole dataset into memory.
    """

    def __init__(self, db_name: str):
        self.db_name = db_name
        self.conn = self.get_conn()
        self.loaded = False

    @abstractmethod
    def get_conn(self) -> duckdb.DuckDBPyConnection:
        """Create and configure a DuckDB connection."""

    @property
    def dialect(self) -> str:
        return "duckdb"

    @property
    def string_types(self) -> set[str]:
        return {"VARCHAR", "CHAR", "BPCHAR", "TEXT", "STRING"}

    @property
    def json_array_agg_func(self) -> str:
        return "JSON_GROUP_ARRAY"

    async def aget_server_version(self) -> DBServerVersion:
        version = [int(i) for i in duckdb.__version__.split(".") if i.isdigit()]
        server_version = DBServerVersion(major=version[0], minor=version[1])
        if len(version) > 2:
            server_version.patch = version[2]
        return server_version

    @abstractmethod
    async def read_files(self) -> list[tuple[str, str]]:
        """Return (table name, file path) pairs for the files to load."""

    async def load_data(self) -> None:
        if self.loaded:
            return

        # Reset the connection before loading fresh data
        self.conn = self.get_conn()
        files = await self.read_files()
        for name, file_path in files:
            extension = Path(file_path).suffix.split(".")[-1].lower()
            if extension == "xlsx":
                sql = f"select * from read_xlsx('{file_path}')"
                self.conn.register(f"{name}", self.conn.sql(sql))
            elif extension == "csv":
                sql = f"select * from read_csv('{file_path}')"
                self.conn.register(f"{name}", self.conn.sql(sql))
            elif extension == "json":
                sql = f"select * from read_json('{file_path}')"
                self.conn.register(f"{name}", self.conn.sql(sql))
            elif extension == "parquet":
                sql = f"select * from read_parquet('{file_path}')"
                self.conn.register(f"{name}", self.conn.sql(sql))
            else:
                wb = read_excel(file_path)
                record_batch = wb.load_sheet(wb.sheet_names[0]).to_arrow()
                self.conn.register(f"{name}", pa.Table.from_batches([record_batch]))
        self.loaded = True

    async def aexecute_raw_sql(self, sql: str) -> tuple[list[str], list[tuple[Any, ...]], SQLException | None]:
        try:
            df: duckdb.DuckDBPyRelation = self.conn.sql(sql)
            return df.columns, df.fetchall(), None
        except duckdb.Error as e:
            return [], [], SQLException(error_type=SQLError.SyntaxError, msg=str(e))

    async def aget_metadata(self) -> DatabaseMetadata:
        await self.load_data()

        _, rows, err = await self.aexecute_raw_sql(METADATA_SQL.format(db_name=self.db_name))
        if err:
            raise ConnectionError(err.msg)
        if not rows or not rows[0][1]:
            return DatabaseMetadata(name=self.db_name)
        metadata = DatabaseMetadata.model_validate({
            "name": self.db_name,
            "tables": orjson.loads(cast(str, rows[0][1])),
        })
        return metadata


class DuckdbLocalStore(DuckdbBase):
    def __init__(self, dir_path: str | Path):
        if isinstance(dir_path, str):
            dir_path = Path(dir_path)
        super().__init__(dir_path.stem)

        self.dir_path = dir_path
        self.conn = self.get_conn()
        self.loaded = False

    def get_conn(self) -> duckdb.DuckDBPyConnection:
        conn = duckdb.connect(config={"allow_unsigned_extensions": True})
        conn.load_extension("excel")
        return conn

    async def conn_test(self) -> bool:
        if self.dir_path.is_dir():
            await self.load_data()
            return True
        return False

    async def read_files(self) -> list[tuple[str, str]]:
        excel_files: list[tuple[str, str]] = []
        for file_path in self.dir_path.glob("*"):
            if file_path.is_file() and file_path.suffix.split(".")[-1].lower() in supported_file_extensions:
                excel_files.append((file_path.stem, str(file_path.absolute())))

        return excel_files


class DuckdbS3Store(DuckdbBase):
    def __init__(
        self,
        host: str,
        port: int,
        access_key: str,
        secret_key: str,
        bucket: str,
        region: str = "",
        use_ssl: bool = False,
    ):
        self.host = host
        self.port = port
        self.endpoint = f"{self.host}:{self.port}"
        self.access_key = access_key
        self.secret_key = secret_key
        self.bucket = bucket
        self.db_name = bucket
        self.region = region
        self.use_ssl = use_ssl
        self.conn = self.get_conn()
        super().__init__(self.db_name)

        if self.use_ssl:
            self.endpoint_url = f"https://{self.endpoint}"
        else:
            self.endpoint_url = f"http://{self.endpoint}"
        self.session = get_session()

    def get_conn(self) -> duckdb.DuckDBPyConnection:
        conn = duckdb.connect()
        conn.load_extension("httpfs")
        sql = f"""
        CREATE OR REPLACE SECRET secret (
            TYPE s3,
            ENDPOINT '{self.endpoint}',
            KEY_ID '{self.access_key}',
            SECRET '{self.secret_key}',
            USE_SSL {orjson.dumps(self.use_ssl).decode()},
            URL_STYLE 'path',
            REGION '{self.region}'
        );
        """
        conn.execute(sql)
        return conn

    async def conn_test(self) -> bool:
        try:
            async with self.session.create_client(
                service_name="s3",
                endpoint_url=self.endpoint_url,
                aws_secret_access_key=self.secret_key,
                aws_access_key_id=self.access_key,
                region_name=self.region,
            ) as client:
                await client.head_bucket(Bucket=self.bucket)
                return True
        except ClientError:
            return False

    async def read_files(self) -> list[tuple[str, str]]:
        excel_files: list[tuple[str, str]] = []
        async with self.session.create_client(
            service_name="s3",
            endpoint_url=self.endpoint_url,
            aws_secret_access_key=self.secret_key,
            aws_access_key_id=self.access_key,
            region_name=self.region,
        ) as client:
            paginator = client.get_paginator("list_objects_v2")
            async for result in paginator.paginate(Bucket=self.bucket):
                for obj in result.get("Contents", []):
                    if obj["Key"].split(".")[-1].lower() in supported_file_extensions:
                        excel_files.append((obj["Key"].split(".")[0], f"s3://{self.bucket}/{obj['Key']}"))
        return excel_files
@@ -0,0 +1,19 @@
SELECT
    '{db_name}' AS db_name,
    JSON_GROUP_ARRAY(JSON_OBJECT('name', t.table_name, 'columns', t.columns)) AS tables
FROM (
    SELECT
        t.table_name,
        JSON_GROUP_OBJECT(
            t.column_name,
            JSON_OBJECT(
                'name', t.column_name,
                'default', t.column_default,
                'nullable', CASE WHEN t.is_nullable = 'YES' THEN TRUE ELSE FALSE END,
                'type', UPPER(t.data_type)
            )
        ) AS columns
    FROM information_schema.columns AS t
    GROUP BY
        t.table_name
) AS t
@@ -0,0 +1,91 @@
SELECT
    t.table_schema,
    JSON_ARRAYAGG(
        JSON_OBJECT(
            'name', t.table_name,
            'columns', t.columns,
            'primary_keys', COALESCE(t.primary_keys, JSON_ARRAY()),
            'foreign_keys', COALESCE(t.foreign_keys, JSON_ARRAY()),
            'comment', t.table_comment
        )
    ) AS tables
FROM (
    SELECT
        t1.table_schema,
        t1.table_name,
        t1.table_comment,
        t2.columns,
        t3.primary_keys,
        t4.foreign_keys
    FROM information_schema.tables AS t1
    JOIN (
        SELECT
            t.table_schema,
            t.table_name,
            JSON_OBJECTAGG(
                t.column_name, JSON_OBJECT(
                    'name', t.column_name,
                    'default', t.column_default,
                    'nullable', CASE WHEN t.is_nullable = 'YES' THEN CAST(TRUE AS JSON) ELSE CAST(FALSE AS JSON) END,
                    'type', UPPER(t.data_type),
                    'auto_increment', CASE
                        WHEN t.extra = 'auto_increment'
                        THEN CAST(TRUE AS JSON)
                        ELSE CAST(FALSE AS JSON)
                    END,
                    'comment', t.column_comment
                )
            ) AS columns
        FROM information_schema.columns AS t
        WHERE
            table_schema = '{db_name}'
        GROUP BY
            t.table_schema,
            t.table_name
    ) AS t2
        ON t1.table_schema = t2.table_schema AND t1.table_name = t2.table_name
    LEFT JOIN (
        SELECT
            kcu.table_schema,
            kcu.table_name,
            JSON_ARRAYAGG(kcu.column_name) AS primary_keys
        FROM information_schema.key_column_usage AS kcu
        JOIN information_schema.table_constraints AS tc
            ON kcu.table_schema = '{db_name}'
            AND kcu.constraint_name = tc.constraint_name
            AND kcu.constraint_schema = tc.constraint_schema
            AND kcu.table_name = tc.table_name
            AND tc.constraint_type = 'PRIMARY KEY'
        GROUP BY
            kcu.table_schema,
            kcu.table_name
    ) AS t3
        ON t1.table_schema = t3.table_schema AND t1.table_name = t3.table_name
    LEFT JOIN (
        SELECT
            kcu.table_schema,
            kcu.table_name,
            JSON_ARRAYAGG(
                JSON_OBJECT(
                    'name', kcu.constraint_name,
                    'column', kcu.column_name,
                    'referenced_schema', kcu.referenced_table_schema,
                    'referenced_table', kcu.referenced_table_name,
                    'referenced_column', kcu.referenced_column_name
                )
            ) AS foreign_keys
        FROM information_schema.key_column_usage AS kcu
        JOIN information_schema.table_constraints AS tc
            ON kcu.table_schema = '{db_name}'
            AND kcu.constraint_name = tc.constraint_name
            AND kcu.constraint_schema = tc.constraint_schema
            AND kcu.table_name = tc.table_name
            AND tc.constraint_type = 'FOREIGN KEY'
        GROUP BY
            kcu.table_schema,
            kcu.table_name
    ) AS t4
        ON t1.table_schema = t4.table_schema AND t1.table_name = t4.table_name
) AS t
GROUP BY
    t.table_schema
@@ -0,0 +1,174 @@
SELECT
    t.table_schema,
    JSON_AGG(
        JSON_BUILD_OBJECT(
            'name',
            t.table_name,
            'schema',
            t.table_schema,
            'columns',
            t.columns,
            'primary_keys',
            COALESCE(t.primary_keys, CAST('[]' AS JSON)),
            'foreign_keys',
            COALESCE(t.foreign_keys, CAST('[]' AS JSON)),
            'comment',
            t.comment
        )
    ) AS tables
FROM (
    SELECT
        t1.table_schema,
        t1.table_name,
        t1.comment,
        t2.columns,
        t3.primary_keys,
        t4.foreign_keys
    FROM (
        SELECT
            n.nspname AS table_schema,
            c.relname AS table_name,
            COALESCE(OBJ_DESCRIPTION(c.oid), '') AS comment
        FROM pg_class AS c
        JOIN pg_namespace AS n
            ON n.oid = c.relnamespace AND c.relkind = 'r' AND n.nspname = '{schema}'
    ) AS t1
    JOIN (
        SELECT
            t.table_schema,
            t.table_name,
            JSON_OBJECT_AGG(
                t.column_name,
                JSON_BUILD_OBJECT(
                    'name',
                    t.column_name,
                    'default',
                    t.column_default,
                    'nullable',
                    CASE WHEN t.is_nullable = 'YES' THEN TRUE ELSE FALSE END,
                    'type',
                    UPPER(t.data_type),
                    'auto_increment',
                    t.auto_increment,
                    'comment',
                    COALESCE(t.comment, '')
                )
            ) AS columns
        FROM (
            SELECT
                c.table_schema,
                c.table_name,
                c.column_name,
                c.data_type,
                c.column_default,
                c.is_nullable,
                COALESCE(coldesc.comment, '') AS comment,
                CASE
                    WHEN c.is_identity = 'YES'
                    OR c.column_default LIKE 'nextval%'
                    OR c.column_default LIKE 'nextval(%'
                    THEN TRUE
                    ELSE FALSE
                END AS auto_increment
            FROM information_schema.columns AS c
            JOIN (
                SELECT
                    n.nspname AS table_schema,
                    c.relname AS table_name,
                    a.attname AS column_name,
                    COL_DESCRIPTION(c.oid, a.attnum) AS comment
                FROM pg_class AS c
                JOIN pg_namespace AS n
                    ON n.oid = c.relnamespace
                JOIN pg_attribute AS a
                    ON a.attrelid = c.oid
                WHERE
                    a.attnum > 0 AND NOT a.attisdropped AND n.nspname = '{schema}'
            ) AS coldesc
                ON c.table_schema = '{schema}'
                AND c.table_schema = coldesc.table_schema
                AND c.table_name = coldesc.table_name
                AND c.column_name = coldesc.column_name
        ) AS t
        GROUP BY
            t.table_schema,
            t.table_name
    ) AS t2
        ON t1.table_schema = t2.table_schema AND t1.table_name = t2.table_name
    LEFT JOIN (
        SELECT
            CAST(CAST(connamespace AS REGNAMESPACE) AS TEXT) AS table_schema,
            CASE
                WHEN POSITION('.' IN CAST(CAST(conrelid AS REGCLASS) AS TEXT)) > 0
                THEN SPLIT_PART(CAST(CAST(conrelid AS REGCLASS) AS TEXT), '.', 2)
                ELSE CAST(CAST(conrelid AS REGCLASS) AS TEXT)
            END AS table_name,
            TO_JSON(STRING_TO_ARRAY(SUBSTRING(PG_GET_CONSTRAINTDEF(oid) FROM '\((.*?)\)'), ',')) AS primary_keys
        FROM pg_constraint
        WHERE
            contype = 'p' AND CAST(CAST(connamespace AS REGNAMESPACE) AS TEXT) = '{schema}'
    ) AS t3
        ON t1.table_schema = t3.table_schema AND t1.table_name = t3.table_name
    LEFT JOIN (
        SELECT
            t.table_schema,
            t.table_name,
            JSON_AGG(
                JSON_BUILD_OBJECT(
                    'name',
                    t.name,
                    'column',
                    t.column_name,
                    'referenced_schema',
                    t.referenced_table_schema,
                    'referenced_table',
                    t.referenced_table_name,
                    'referenced_column',
                    t.referenced_column_name
                )
            ) AS foreign_keys
        FROM (
            SELECT
                c.conname AS name,
                n.nspname AS table_schema,
                CASE
                    WHEN POSITION('.' IN CAST(CAST(conrelid AS REGCLASS) AS TEXT)) > 0
                    THEN SPLIT_PART(CAST(CAST(conrelid AS REGCLASS) AS TEXT), '.', 2)
                    ELSE CAST(CAST(conrelid AS REGCLASS) AS TEXT)
                END AS table_name,
                a.attname AS column_name,
                nr.nspname AS referenced_table_schema,
                CASE
                    WHEN POSITION('.' IN CAST(CAST(confrelid AS REGCLASS) AS TEXT)) > 0
                    THEN SPLIT_PART(CAST(CAST(confrelid AS REGCLASS) AS TEXT), '.', 2)
                    ELSE CAST(CAST(confrelid AS REGCLASS) AS TEXT)
                END AS referenced_table_name,
                af.attname AS referenced_column_name
            FROM pg_constraint AS c
            JOIN pg_attribute AS a
                ON a.attnum = ANY(
                    c.conkey
                ) AND a.attrelid = c.conrelid
            JOIN pg_class AS cl
                ON cl.oid = c.conrelid
            JOIN pg_namespace AS n
                ON n.oid = cl.relnamespace
            JOIN pg_attribute AS af
                ON af.attnum = ANY(
                    c.confkey
                ) AND af.attrelid = c.confrelid
            JOIN pg_class AS clf
                ON clf.oid = c.confrelid
            JOIN pg_namespace AS nr
                ON nr.oid = clf.relnamespace
            WHERE
                c.contype = 'f' AND CAST(CAST(connamespace AS REGNAMESPACE) AS TEXT) = '{schema}'
        ) AS t
        GROUP BY
            t.table_schema,
            t.table_name
    ) AS t4
        ON t1.table_schema = t4.table_schema AND t1.table_name = t4.table_name
) AS t
GROUP BY
    t.table_schema
@@ -0,0 +1,66 @@
SELECT
    JSON_GROUP_ARRAY(
        JSON_OBJECT(
            'name', m.name,
            'schema', 'main',
            'columns', JSON(t1.columns),
            'primary_keys', JSON(t2.primary_keys),
            'foreign_keys', COALESCE(JSON(t3.foreign_keys), JSON_ARRAY())
        )
    ) AS tables
FROM sqlite_master AS m
JOIN (
    SELECT
        m.name,
        JSON_GROUP_OBJECT(
            p.name,
            JSON_OBJECT(
                'name', p.name,
                'type', CASE
                    WHEN INSTR(UPPER(p.type), '(') > 0
                    THEN SUBSTRING(UPPER(p.type), 1, INSTR(UPPER(p.type), '(') - 1)
                    ELSE UPPER(p.type)
                END,
                'nullable', (
                    CASE WHEN p."notnull" = 0 THEN TRUE ELSE FALSE END
                ),
                'default', p.dflt_value
            )
        ) AS columns
    FROM sqlite_master AS m
    JOIN PRAGMA_TABLE_INFO(m.name) AS p
        ON m.type = 'table'
    GROUP BY
        m.name
) AS t1
    ON m.name = t1.name
LEFT JOIN (
    SELECT
        m.name,
        JSON_GROUP_ARRAY(p.name) AS primary_keys
    FROM sqlite_master AS m
    JOIN PRAGMA_TABLE_INFO(m.name) AS p
        ON m.type = 'table' AND p.pk > 0
    GROUP BY
        m.name
) AS t2
    ON m.name = t2.name
LEFT JOIN (
    SELECT
        m.name,
        JSON_GROUP_ARRAY(
            JSON_OBJECT(
                'name', 'fk_' || m.tbl_name || '_' || fk."from" || '_' || fk."table" || '_' || fk."to",
                'column', fk."from",
                'referenced_schema', '',
                'referenced_table', fk."table",
                'referenced_column', fk."to"
            )
        ) AS foreign_keys
    FROM sqlite_master AS m
    JOIN PRAGMA_FOREIGN_KEY_LIST(m.name) AS fk
        ON m.type = 'table'
    GROUP BY
        m.tbl_name
) AS t3
    ON m.name = t3.name
@ -0,0 +1,98 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import orjson
|
||||
from mysql.connector.aio import MySQLConnection, connect
|
||||
|
||||
from .base import DatabaseMetadata, DataSourceBase, DBServerVersion, SQLError, SQLException
|
||||
|
||||
with open(Path(__file__).parent / "metadata_sql" / "mysql.sql", encoding="utf-8") as f:
|
||||
METADATA_SQL = f.read()
|
||||
|
||||
version_p = re.compile(r"\b(v)?(\d+)\.(\d+)(?:\.(\d+))?(?:-([a-zA-Z0-9]+))?\b")
|
||||
|
||||
|
||||
class Mysql(DataSourceBase):
|
||||
def __init__(self, host: str, port: int, username: str, password: str, db_name: str):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.db_name = db_name
|
||||
|
||||
@property
|
||||
def dialect(self) -> str:
|
||||
return "mysql"
|
||||
|
||||
@property
|
||||
def string_types(self) -> set[str]:
|
||||
return {
|
||||
"CHAR",
|
||||
"VARCHAR",
|
||||
"TEXT",
|
||||
"TINYTEXT",
|
||||
"MEDIUMTEXT",
|
||||
"LONGTEXT",
|
||||
}
|
||||
|
||||
@property
|
||||
def json_array_agg_func(self) -> str:
|
||||
return "JSON_ARRAYAGG"
|
||||
|
||||
async def acreate_connection(self) -> MySQLConnection:
|
||||
conn = await connect(
|
||||
host=self.host, port=self.port, user=self.username, password=self.password, database=self.db_name
|
||||
)
|
||||
return cast(MySQLConnection, conn)
|
||||
|
||||
async def aget_server_version(self) -> DBServerVersion:
|
||||
sql = "SELECT VERSION()"
|
||||
_, rows, _ = await self.aexecute_raw_sql(sql)
|
||||
version = cast(str, rows[0][0])
|
||||
match = version_p.match(version)
|
||||
if match:
|
||||
has_v, major, minor, patch, suffix = match.groups()
|
||||
server_version = DBServerVersion(major=int(major), minor=int(minor))
|
||||
if patch:
|
||||
server_version.patch = int(patch)
|
||||
return server_version
|
||||
else:
|
||||
raise ValueError(f"Invalid version: {version}")
|
||||
|
||||
async def aexecute_raw_sql(self, sql: str) -> tuple[list[str], list[tuple[Any, ...]], SQLException | None]:
|
||||
if not sql:
|
||||
return [], [], SQLException(error_type=SQLError.EmptySQL, msg="SQL statement cannot be empty")
|
||||
|
||||
try:
|
||||
conn = await self.acreate_connection()
|
||||
except Exception as e:
|
||||
return [], [], SQLException(error_type=SQLError.DBError, msg=e.args[1])
|
||||
|
||||
try:
|
||||
cursor = await conn.cursor()
|
||||
await cursor.execute(sql)
|
||||
res = await cursor.fetchall()
|
||||
if cursor.description:
|
||||
cols = [i[0] for i in cursor.description]
|
||||
else:
|
||||
cols = []
|
||||
await cursor.close()
|
||||
await conn.close()
|
||||
return cols, res, None
|
||||
except Exception as e:
|
||||
return [], [], SQLException(error_type=SQLError.SyntaxError, msg=e.args[1], code=e.args[0])
|
||||
|
||||
async def aget_metadata(self) -> DatabaseMetadata:
|
||||
sql = METADATA_SQL.format(db_name=self.db_name)
|
||||
cols, rows, err = await self.aexecute_raw_sql(sql)
|
||||
if err:
|
||||
raise ConnectionError(err.msg)
|
||||
if not rows or not rows[0][1]:
|
||||
return DatabaseMetadata(name=self.db_name)
|
||||
metadata = DatabaseMetadata.model_validate({
|
||||
"name": rows[0][0],
|
||||
"tables": orjson.loads(cast(str, rows[0][1])),
|
||||
})
|
||||
return metadata
|
||||
|
|
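For orientation, a minimal usage sketch of the Mysql data source above. The connection parameters are placeholders and the module path is assumed from the package layout, not taken from this diff.

```python
import asyncio

from dative.core.data_source.mysql import Mysql  # assumed module path


async def main() -> None:
    ds = Mysql(host="127.0.0.1", port=3306, username="root", password="secret", db_name="demo")
    print(await ds.aget_server_version())
    cols, rows, err = await ds.aexecute_raw_sql("SELECT 1")
    print(cols, rows, err)
    metadata = await ds.aget_metadata()
    print(metadata.name)


asyncio.run(main())
```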
@ -0,0 +1,95 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import asyncpg
|
||||
import orjson
|
||||
|
||||
from .base import DatabaseMetadata, DataSourceBase, DBServerVersion, SQLError, SQLException
|
||||
|
||||
with open(Path(__file__).parent / "metadata_sql" / "postgres.sql", encoding="utf-8") as f:
|
||||
METADATA_SQL = f.read()
|
||||
|
||||
|
||||
version_pattern = re.compile(
|
||||
r".*(?:PostgreSQL|EnterpriseDB) "
|
||||
r"(\d+)\.?(\d+)?(?:\.(\d+))?(?:\.\d+)?(?:devel|beta)?"
|
||||
)
|
||||
|
||||
|
||||
class Postgres(DataSourceBase):
|
||||
"""
|
||||
Postgres does not support cross-database queries.
|
||||
"""
|
||||
|
||||
def __init__(self, host: str, port: int, username: str, password: str, db_name: str, schema: str = "public"):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.db_name = db_name
|
||||
self.schema = schema
|
||||
|
||||
@property
|
||||
def dialect(self) -> str:
|
||||
return "postgres"
|
||||
|
||||
@property
|
||||
def string_types(self) -> set[str]:
|
||||
return {"CHARACTER VARYING", "VARCHAR", "CHAR", "CHARACTER", "TEXT"}
|
||||
|
||||
@property
|
||||
def json_array_agg_func(self) -> str:
|
||||
return "JSON_AGG"
|
||||
|
||||
async def acreate_connection(self) -> asyncpg.Connection:
|
||||
conn = await asyncpg.connect(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
user=self.username,
|
||||
password=self.password,
|
||||
database=self.db_name,
|
||||
)
|
||||
return conn
|
||||
|
||||
async def aget_server_version(self) -> DBServerVersion:
|
||||
sql = "SELECT VERSION() as v"
|
||||
_, rows, _ = await self.aexecute_raw_sql(sql)
|
||||
version_str = cast(str, rows[0][0])
|
||||
m = version_pattern.match(version_str)
|
||||
if not m:
|
||||
raise AssertionError("Could not determine version from string '%s'" % version_str)
|
||||
version = [int(x) for x in m.group(1, 2, 3) if x is not None]
|
||||
server_version = DBServerVersion(major=version[0], minor=version[1])
|
||||
if len(version) > 2:
|
||||
server_version.patch = version[2]
|
||||
return server_version
|
||||
|
||||
async def aexecute_raw_sql(self, sql: str) -> tuple[list[str], list[tuple[Any, ...]], SQLException | None]:
|
||||
if not sql:
|
||||
return [], [], SQLException(error_type=SQLError.EmptySQL, msg="SQL statement cannot be empty")
|
||||
|
||||
try:
|
||||
conn = await self.acreate_connection()
|
||||
except Exception as e:
|
||||
return [], [], SQLException(error_type=SQLError.DBError, msg=str(e))
|
||||
|
||||
try:
|
||||
res: list[asyncpg.Record] = await conn.fetch(sql)
|
||||
return list(res[0].keys()), [tuple(x.values()) for x in res], None
|
||||
except Exception as e:
|
||||
return [], [], SQLException(error_type=SQLError.SyntaxError, msg=str(e))
|
||||
|
||||
async def aget_metadata(self) -> DatabaseMetadata:
|
||||
sql = METADATA_SQL.format(schema=self.schema)
|
||||
cols, rows, err = await self.aexecute_raw_sql(sql)
|
||||
if err:
|
||||
raise ConnectionError(err.msg)
|
||||
if not rows or not rows[0][1]:
|
||||
return DatabaseMetadata(name=self.db_name)
|
||||
metadata = DatabaseMetadata.model_validate({
|
||||
"name": self.db_name,
|
||||
"tables": orjson.loads(cast(str, rows[0][1])),
|
||||
})
|
||||
return metadata
|
||||
|
|
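For reference, this is how the version_pattern above behaves on a typical string returned by SELECT VERSION(); illustration only, the sample string is made up.

```python
import re

version_pattern = re.compile(
    r".*(?:PostgreSQL|EnterpriseDB) "
    r"(\d+)\.?(\d+)?(?:\.(\d+))?(?:\.\d+)?(?:devel|beta)?"
)

sample = "PostgreSQL 15.4 (Debian 15.4-1) on x86_64-pc-linux-gnu, compiled by gcc"
m = version_pattern.match(sample)
assert m is not None
print([int(x) for x in m.group(1, 2, 3) if x is not None])  # -> [15, 4]
```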
@ -0,0 +1,64 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import aiosqlite
|
||||
import orjson
|
||||
|
||||
from .base import DatabaseMetadata, DataSourceBase, DBServerVersion, SQLError, SQLException
|
||||
|
||||
with open(Path(__file__).parent / "metadata_sql" / "sqlite.sql", encoding="utf-8") as f:
|
||||
METADATA_SQL = f.read()
|
||||
|
||||
|
||||
class Sqlite(DataSourceBase):
|
||||
def __init__(self, db_path: str | Path):
|
||||
if isinstance(db_path, str):
|
||||
db_path = Path(db_path)
|
||||
self.db_path = db_path
|
||||
self.db_name = db_path.stem
|
||||
|
||||
@property
|
||||
def dialect(self) -> str:
|
||||
return "sqlite"
|
||||
|
||||
@property
|
||||
def string_types(self) -> set[str]:
|
||||
return {"TEXT"}
|
||||
|
||||
@property
|
||||
def json_array_agg_func(self) -> str:
|
||||
return "JSON_GROUP_ARRAY"
|
||||
|
||||
async def conn_test(self) -> bool:
|
||||
if self.db_path.is_file():
|
||||
return True
|
||||
return False
|
||||
|
||||
async def aget_server_version(self) -> DBServerVersion:
|
||||
version = [int(i) for i in aiosqlite.sqlite_version.split(".") if i.isdigit()]
|
||||
server_version = DBServerVersion(major=version[0], minor=version[1])
|
||||
if len(version) > 2:
|
||||
server_version.patch = version[2]
|
||||
return server_version
|
||||
|
||||
async def aget_metadata(self) -> DatabaseMetadata:
|
||||
_, rows, err = await self.aexecute_raw_sql(METADATA_SQL)
|
||||
if err:
|
||||
raise ConnectionError(err.msg)
|
||||
if not rows or not rows[0][1]:
|
||||
return DatabaseMetadata(name=self.db_name)
|
||||
metadata = DatabaseMetadata.model_validate({
|
||||
"name": self.db_name,
|
||||
"tables": orjson.loads(cast(str, rows[0][0])),
|
||||
})
|
||||
return metadata
|
||||
|
||||
async def aexecute_raw_sql(self, sql: str) -> tuple[list[str], list[tuple[Any, ...]], SQLException | None]:
|
||||
try:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
cursor = await db.execute(sql)
|
||||
res = await cursor.fetchall()
|
||||
return [i[0] for i in cursor.description], cast(list[tuple[Any, ...]], res), None
|
||||
except Exception as e:
|
||||
return [], [], SQLException(error_type=SQLError.SyntaxError, msg=str(e))
|
||||
|
|
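A minimal usage sketch of the Sqlite data source above; the database path is a placeholder and the module path is an assumption based on the package layout.

```python
import asyncio

from dative.core.data_source.sqlite import Sqlite  # assumed module path


async def main() -> None:
    ds = Sqlite("datasets/example.db")  # placeholder path
    if not await ds.conn_test():
        raise FileNotFoundError(ds.db_path)
    print(await ds.aget_server_version())
    cols, rows, err = await ds.aexecute_raw_sql("SELECT 1")
    print(cols, rows, err)


asyncio.run(main())
```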
@ -0,0 +1 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from ..data_source.base import DBTable
|
||||
from ..utils import JOIN_CHAR, calculate_metrics, parse_real_cols
|
||||
|
||||
|
||||
def calculate_retrieval_metrics(
|
||||
tables: list[DBTable], gold_sql: str, dialect="mysql", case_insensitive: bool = True
|
||||
) -> tuple[float, float, float]:
|
||||
gold_cols = parse_real_cols(gold_sql, dialect, case_insensitive)
|
||||
recall_cols = set()
|
||||
for table in tables:
|
||||
for cn in table.columns.keys():
|
||||
col_name = f"{table.name}{JOIN_CHAR}{cn}"
|
||||
if case_insensitive:
|
||||
col_name = col_name.lower()
|
||||
recall_cols.add(col_name)
|
||||
|
||||
tp = len(gold_cols & recall_cols)
|
||||
fp = len(recall_cols - gold_cols)
|
||||
fn = len(gold_cols - recall_cols)
|
||||
return calculate_metrics(tp, fp, fn)
|
||||
|
|
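A worked example of the column-set arithmetic above, re-done with plain sets so it runs without constructing DBTable objects (illustration only; the column names are made up).

```python
gold_cols = {"users.id", "users.name", "orders.total"}     # columns referenced by the gold SQL
recall_cols = {"users.id", "users.name", "orders.status"}  # columns of the retrieved tables

tp = len(gold_cols & recall_cols)  # 2
fp = len(recall_cols - gold_cols)  # 1
fn = len(gold_cols - recall_cols)  # 1
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)
print(precision, recall, f1)  # 0.666..., 0.666..., 0.666...
```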
@ -0,0 +1,115 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ..utils import calculate_metrics
|
||||
|
||||
|
||||
def calculate_ex(predicted_res: list[tuple[Any, ...]], ground_truth_res: list[tuple[Any, ...]]) -> int:
|
||||
res = 0
|
||||
if set(predicted_res) == set(ground_truth_res):
|
||||
res = 1
|
||||
return res
|
||||
|
||||
|
||||
def calculate_row_match(
|
||||
predicted_row: tuple[Any, ...], ground_truth_row: tuple[Any, ...]
|
||||
) -> tuple[float, float, float]:
|
||||
"""
|
||||
Calculate the matching percentage for a single row.
|
||||
|
||||
Args:
|
||||
predicted_row (tuple): The predicted row values.
|
||||
ground_truth_row (tuple): The actual row values from ground truth.
|
||||
|
||||
Returns:
|
||||
float: The match percentage (0 to 1 scale).
|
||||
"""
|
||||
total_columns = len(ground_truth_row)
|
||||
matches = 0
|
||||
element_in_pred_only = 0
|
||||
element_in_truth_only = 0
|
||||
for pred_val in predicted_row:
|
||||
if pred_val in ground_truth_row:
|
||||
matches += 1
|
||||
else:
|
||||
element_in_pred_only += 1
|
||||
for truth_val in ground_truth_row:
|
||||
if truth_val not in predicted_row:
|
||||
element_in_truth_only += 1
|
||||
match_percentage = matches / total_columns
|
||||
pred_only_percentage = element_in_pred_only / total_columns
|
||||
truth_only_percentage = element_in_truth_only / total_columns
|
||||
return match_percentage, pred_only_percentage, truth_only_percentage
|
||||
|
||||
|
||||
def calculate_f1(
|
||||
predicted: list[tuple[Any, ...]], ground_truth: list[tuple[Any, ...]]
|
||||
) -> tuple[float, float, float]:
|
||||
"""
|
||||
Calculate the F1 score based on sets of predicted results and ground truth results,
|
||||
where each element (tuple) represents a row from the database with multiple columns.
|
||||
|
||||
Args:
|
||||
predicted (set of tuples): Predicted results from SQL query.
|
||||
ground_truth (set of tuples): Actual results expected (ground truth).
|
||||
|
||||
Returns:
|
||||
float: The calculated F1 score.
|
||||
"""
|
||||
# if both predicted and ground_truth are empty, return 1.0 for f1_score
|
||||
if not predicted and not ground_truth:
|
||||
return 1.0, 1.0, 1.0
|
||||
|
||||
# Calculate matching scores for each possible pair
|
||||
match_scores: list[float] = []
|
||||
pred_only_scores: list[float] = []
|
||||
truth_only_scores: list[float] = []
|
||||
for i, gt_row in enumerate(ground_truth):
|
||||
# rows only in the ground truth results
|
||||
if i >= len(predicted):
|
||||
match_scores.append(0)
|
||||
truth_only_scores.append(1)
|
||||
continue
|
||||
pred_row = predicted[i]
|
||||
match_score, pred_only_score, truth_only_score = calculate_row_match(pred_row, gt_row)
|
||||
match_scores.append(match_score)
|
||||
pred_only_scores.append(pred_only_score)
|
||||
truth_only_scores.append(truth_only_score)
|
||||
|
||||
# rows only in the predicted results
|
||||
for i in range(len(predicted) - len(ground_truth)):
|
||||
match_scores.append(0)
|
||||
pred_only_scores.append(1)
|
||||
truth_only_scores.append(0)
|
||||
|
||||
tp = sum(match_scores)
|
||||
fp = sum(pred_only_scores)
|
||||
fn = sum(truth_only_scores)
|
||||
|
||||
precision, recall, f1_score = calculate_metrics(tp, fp, fn)
|
||||
return precision, recall, f1_score
|
||||
|
||||
def calculate_ves(
|
||||
predicted_res: list[tuple[Any, ...]],
|
||||
ground_truth_res: list[tuple[Any, ...]],
|
||||
pred_cost_time: float,
|
||||
ground_truth_cost_time: float
|
||||
):
|
||||
time_ratio = 0
|
||||
if set(predicted_res) == set(ground_truth_res):
|
||||
time_ratio = ground_truth_cost_time / pred_cost_time
|
||||
|
||||
if time_ratio == 0:
|
||||
reward = 0
|
||||
elif time_ratio >= 2:
|
||||
reward = 1.25
|
||||
elif 1 <= time_ratio < 2:
|
||||
reward = 1
|
||||
elif 0.5 <= time_ratio < 1:
|
||||
reward = 0.75
|
||||
elif 0.25 <= time_ratio < 0.5:
|
||||
reward = 0.5
|
||||
else:
|
||||
reward = 0.25
|
||||
return reward
|
||||
|
|
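A quick illustration of the reward tiers in calculate_ves above. The numbers are made up and the module path is assumed from the test imports: the result sets match, so only the time ratio decides the reward.

```python
from dative.core.evaluation.sql_generation import calculate_ves  # assumed module path

predicted = [(1, "A"), (2, "B")]
ground_truth = [(2, "B"), (1, "A")]  # same rows, different order -> sets match

ground_truth_cost_time = 0.10  # seconds, made-up timings
pred_cost_time = 0.08
time_ratio = ground_truth_cost_time / pred_cost_time  # 1.25, falls in [1, 2)

reward = calculate_ves(predicted, ground_truth, pred_cost_time, ground_truth_cost_time)
print(time_ratio, reward)  # 1.25 1
```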
@ -0,0 +1,117 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import asyncio
|
||||
from json import JSONDecodeError
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessageChunk, HumanMessage
|
||||
|
||||
from ..utils import cal_tokens, get_beijing_time, text2json
|
||||
|
||||
prompt_prefix = """
|
||||
You are a helpful data analyst who is great at thinking deeply and reasoning about the user's question and the database schema.
|
||||
|
||||
## 1. Database info
|
||||
{db_info}
|
||||
|
||||
## 2. Context Information
|
||||
- Current Time: {current_time}. Use it only when a calculation or conversion based on the current system time is required.
|
||||
|
||||
## 3. Constraints
|
||||
- Generate an optimized SQL query that directly answers the user's question.
|
||||
- The SQL query must be fully formed, valid, and executable.
|
||||
- Do NOT include any explanations, markdown formatting, or comments.
|
||||
- If you want the maximum or minimum value, do not limit it to 1.
|
||||
- When you need to calculate the proportion or other indicators, please use double precision.
|
||||
|
||||
## 4. Evidence
|
||||
{evidence}
|
||||
|
||||
## 5. QUESTION
|
||||
User's Question: {user_query}
|
||||
""" # noqa: E501
|
||||
|
||||
generation_prompt = """
|
||||
Respond in the following JSON format:
|
||||
- If the user query is not related to the database, answer with an empty string.
|
||||
{{
|
||||
"answer": ""
|
||||
}}
|
||||
- If you can answer the question based on the database schema alone and don't need to generate SQL, give the answer directly.
|
||||
{{
|
||||
"answer": "answer based on database schema"
|
||||
}}
|
||||
- If you need to answer the question by querying the database, please generate SQL, select only the necessary fields needed to answer the question, without any missing or extra information.
|
||||
- Prefer using aggregate functions (such as COUNT, SUM, etc.) in SQL query and avoid returning redundant data for post-processing.
|
||||
{{
|
||||
"sql": "Generated SQL query here"
|
||||
}}
|
||||
""" # noqa: E501
|
||||
|
||||
correction_prompt = """
|
||||
There is a SQL query, but an error was reported when it was executed. Analyze why the given SQL query does not produce the correct results, identify the issues, and provide a corrected SQL query that properly answers the user's request.
|
||||
|
||||
- Current SQL Query: {current_sql}
|
||||
- Error: {error}
|
||||
|
||||
**What You Need to Do:**
|
||||
1. **Analyze:** Explain why the current SQL query fails to produce the correct results.
|
||||
2. **Provide a Corrected SQL Query:** Write a revised query that returns the correct results.
|
||||
|
||||
Respond in the following JSON format:
|
||||
{{
|
||||
"sql": "The corrected SQL query should be placed here."
|
||||
}}
|
||||
""" # noqa: E501
|
||||
|
||||
|
||||
async def arun(
|
||||
query: str,
|
||||
llm: BaseChatModel,
|
||||
db_info: str,
|
||||
evidence: str = "",
|
||||
error_sql: str | None = None,
|
||||
error_msg: str | None = None,
|
||||
) -> tuple[str | None, str, int, int, str | None]:
|
||||
answer = None
|
||||
sql = ""
|
||||
error = None
|
||||
data = {
|
||||
"db_info": db_info,
|
||||
"user_query": query,
|
||||
"current_time": await asyncio.to_thread(get_beijing_time),
|
||||
"evidence": evidence or "",
|
||||
}
|
||||
input_tokens, output_tokens = 0, 0
|
||||
if not error_sql:
|
||||
prompt = prompt_prefix + generation_prompt
|
||||
else:
|
||||
assert error_msg, "error_msg is required when you need to correct sql"
|
||||
data["current_sql"] = error_sql
|
||||
data["error"] = error_msg
|
||||
prompt = prompt_prefix + correction_prompt
|
||||
try:
|
||||
content = ""
|
||||
async for chunk in llm.astream([HumanMessage(prompt.format(**data))]):
|
||||
msg = cast(AIMessageChunk, chunk)
|
||||
content += cast(str, msg.content)
|
||||
input_tokens, output_tokens = await asyncio.to_thread(cal_tokens, msg)
|
||||
except Exception as e:
|
||||
raise ValueError(f"调用LLM失败:{e}")
|
||||
|
||||
if content.startswith("SELECT"):
|
||||
sql = content
|
||||
elif content.find("```sql") != -1:
|
||||
sql = content.split("```sql")[1].split("```")[0]
|
||||
else:
|
||||
try:
|
||||
result = await asyncio.to_thread(text2json, content)
|
||||
if "answer" in result:
|
||||
answer = result.get("answer") or ""
|
||||
else:
|
||||
sql = result.get("sql") or ""
|
||||
except JSONDecodeError:
|
||||
error = "Incorrect json format"
|
||||
|
||||
return answer, sql, input_tokens, output_tokens, error
|
||||
|
|
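A small sketch of the happy path for interpreting the model output in arun above, using the text2json helper defined later in core/utils.py; the example reply string is illustrative.

```python
from dative.core.utils import text2json

content = '{"sql": "SELECT COUNT(*) AS n FROM users"}'  # example LLM reply
result = text2json(content)

answer = result.get("answer")  # present only when the model answered directly
sql = result.get("sql") or ""  # otherwise the generated SQL query
print(answer, sql)
```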
@ -0,0 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from .check import SQLCheck
|
||||
from .optimization import SQLOptimization
|
||||
|
||||
__all__ = ["SQLCheck", "SQLOptimization"]
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from sqlglot import exp, parse_one
|
||||
|
||||
|
||||
class SQLCheck:
|
||||
def __init__(self, dialect: str):
|
||||
self.dialect = dialect
|
||||
|
||||
def is_query(self, sql_or_exp: str | exp.Expression) -> bool:
|
||||
"""判断是否是查询语句"""
|
||||
if isinstance(sql_or_exp, exp.Expression):
|
||||
expression = sql_or_exp
|
||||
else:
|
||||
expression = parse_one(sql_or_exp, dialect=self.dialect)
|
||||
return isinstance(expression, exp.Query)
|
||||
|
||||
def syntax_valid(self, sql: str) -> exp.Expression:
|
||||
"""基本语法验证"""
|
||||
return parse_one(sql, dialect=self.dialect)
|
||||
|
|
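A usage sketch for SQLCheck, importing it through the package __init__ added above; the SQL strings are illustrative.

```python
from sqlglot import ParseError

from dative.core.sql_inspection import SQLCheck

checker = SQLCheck(dialect="mysql")
try:
    expression = checker.syntax_valid("SELECT id, name FROM users WHERE age > 18")
    print(checker.is_query(expression))           # True
    print(checker.is_query("DELETE FROM users"))  # False: not a query statement
except ParseError as exc:
    print(f"invalid SQL: {exc}")
```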
@ -0,0 +1,156 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlglot import ParseError, exp, parse_one
|
||||
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
|
||||
from sqlglot.optimizer.optimizer import RULES, optimize
|
||||
|
||||
|
||||
class SQLOptimization:
|
||||
def __init__(self, dialect: str, db_major_version: int):
|
||||
self.dialect = dialect
|
||||
self.db_major_version = db_major_version
|
||||
|
||||
@staticmethod
|
||||
def cte_to_subquery(expression: exp.Expression) -> exp.Expression:
|
||||
# Collect all CTEs
|
||||
cte_names = {}
|
||||
for cte in expression.find_all(exp.CTE):
|
||||
alias = cte.alias
|
||||
query = cte.this
|
||||
cte_names[alias] = query
|
||||
# Remove the original CTE node
|
||||
if cte.parent:
|
||||
cte.parent.pop()
|
||||
|
||||
if not cte_names:
|
||||
return expression
|
||||
|
||||
# Replace every reference to a CTE with a subquery
|
||||
def replace_cte_with_subquery(node: exp.Expression) -> exp.Expression:
|
||||
if isinstance(node, exp.Table) and node.name in cte_names:
|
||||
subquery = cte_names[node.name]
|
||||
return exp.Subquery(this=subquery.copy(), alias=node.alias or node.name)
|
||||
|
||||
return node
|
||||
|
||||
return expression.transform(replace_cte_with_subquery)
|
||||
|
||||
@staticmethod
|
||||
def optimize_in_limit_subquery(expression: exp.Expression) -> exp.Expression:
|
||||
"""
|
||||
Avoid 'LIMIT & IN/ALL/ANY/SOME subquery'
|
||||
"""
|
||||
for in_expr in expression.find_all(exp.In):
|
||||
subquery = in_expr.args.get("query")
|
||||
if not subquery:
|
||||
continue
|
||||
if subquery.this.find(exp.Limit):
|
||||
t = subquery.this.args.get("from").this
|
||||
if t.args.get("alias"):
|
||||
alias = t.args.get("alias").this.this
|
||||
else:
|
||||
alias = exp.TableAlias(this=exp.Identifier(this="t"))
|
||||
derived_table = exp.Subquery(this=subquery.this.copy(), alias=alias)
|
||||
# Build a new SELECT t.id FROM (subquery) AS t
|
||||
new_subquery_select = exp.select(*subquery.this.expressions).from_(derived_table)
|
||||
# Replace the subquery inside IN
|
||||
in_expr.set("query", exp.Subquery(this=new_subquery_select))
|
||||
|
||||
return expression
|
||||
|
||||
@staticmethod
|
||||
def fix_missing_group_by_when_agg_func(expression: exp.Expression) -> exp.Expression:
|
||||
"""
|
||||
case 2: SELECT a, COUNT(b) FROM x --> SELECT a, COUNT(b) FROM x GROUP BY a
|
||||
case 3: SELECT a FROM x ORDER BY MAX(b) --> SELECT a FROM x GROUP BY a ORDER BY MAX(b)
|
||||
"""
|
||||
for select_expr in expression.find_all(exp.Select):
|
||||
select_agg = False
|
||||
not_agg_query_cols = dict()
|
||||
group_cols = dict()
|
||||
order_by_agg = False
|
||||
for col in select_expr.expressions:
|
||||
if isinstance(col, exp.Column):
|
||||
not_agg_query_cols[col.this.this] = col
|
||||
elif isinstance(col, exp.AggFunc):
|
||||
select_agg = True
|
||||
elif isinstance(col, exp.Alias):
|
||||
if isinstance(col.this, exp.Column):
|
||||
not_agg_query_cols[col.this.this.this] = col.this
|
||||
elif isinstance(col.this, exp.AggFunc):
|
||||
select_agg = True
|
||||
|
||||
if expression.args.get("group"):
|
||||
for col in expression.args["group"].expressions:
|
||||
group_cols[col.this.this] = col
|
||||
if expression.args.get("order"):
|
||||
for order_col in expression.args["order"].expressions:
|
||||
if isinstance(order_col.this, exp.AggFunc):
|
||||
order_by_agg = True
|
||||
|
||||
if group_cols or (select_agg and not_agg_query_cols) or order_by_agg:
|
||||
for col in not_agg_query_cols:
|
||||
if col not in group_cols:
|
||||
group_cols[col] = not_agg_query_cols[col]
|
||||
|
||||
if group_cols:
|
||||
select_expr.set("group", exp.Group(expressions=group_cols.values()))
|
||||
|
||||
return expression
|
||||
|
||||
@staticmethod
|
||||
def set_limit(expression: exp.Expression, result_num_limit: int):
|
||||
if expression.args.get("limit"):
|
||||
limit_exp = cast(exp.Limit, expression.args.get("limit"))
|
||||
limit = min(result_num_limit, int(limit_exp.expression.this))
|
||||
else:
|
||||
limit = result_num_limit
|
||||
expression.set("limit", exp.Limit(expression=exp.Literal.number(limit)))
|
||||
|
||||
def arun(
|
||||
self,
|
||||
sql_or_exp: str | exp.Expression,
|
||||
schema_type: dict[str, dict[str, Any]] | None = None,
|
||||
result_num_limit: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Args:
|
||||
sql_or_exp:
|
||||
schema_type: db schema type, a mapping in one of the following forms:
|
||||
1. {table: {col: type}}
|
||||
2. {db: {table: {col: type}}}
|
||||
3. {catalog: {db: {table: {col: type}}}}
|
||||
result_num_limit: int
|
||||
Returns: str, optimized sql
|
||||
"""
|
||||
if isinstance(sql_or_exp, exp.Expression):
|
||||
expression = sql_or_exp
|
||||
else:
|
||||
try:
|
||||
expression = parse_one(sql_or_exp, dialect=self.dialect)
|
||||
except ParseError:
|
||||
expression = parse_one(sql_or_exp)
|
||||
|
||||
if result_num_limit and result_num_limit > 0:
|
||||
self.set_limit(expression, result_num_limit)
|
||||
|
||||
rules = list(RULES)
|
||||
if self.dialect == "mysql" and self.db_major_version < 8:
|
||||
# CTEs are only supported from MySQL 8.0; otherwise fall back to subqueries
|
||||
rules.remove(eliminate_subqueries)
|
||||
rules.append(self.cte_to_subquery)
|
||||
|
||||
# Rewrite IN subqueries that contain LIMIT
|
||||
rules.append(self.optimize_in_limit_subquery)
|
||||
# For aggregate queries, add the columns missing from GROUP BY
|
||||
rules.append(self.fix_missing_group_by_when_agg_func)
|
||||
expression = optimize(
|
||||
expression,
|
||||
schema=schema_type,
|
||||
dialect=self.dialect,
|
||||
rules=rules, # type: ignore
|
||||
identify=False,
|
||||
)
|
||||
return expression.sql(self.dialect)
|
||||
|
|
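A usage sketch for SQLOptimization, mirroring the unit tests further down; the SQL and schema mapping are made-up examples.

```python
from dative.core.sql_inspection import SQLOptimization

optimizer = SQLOptimization(dialect="mysql", db_major_version=5)
sql = "WITH adults AS (SELECT id, name FROM users WHERE age > 18) SELECT name FROM adults"
schema = {"users": {"id": "INT", "name": "VARCHAR", "age": "INT"}}

# On MySQL 5.x the CTE is rewritten into a subquery and the row count is capped.
print(optimizer.arun(sql, schema_type=schema, result_num_limit=100))
```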
@ -0,0 +1,79 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import asyncio
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from ..utils import cal_tokens, text2json
|
||||
|
||||
prompt = """
|
||||
# Database info
|
||||
{db_info}
|
||||
|
||||
# evidence:
|
||||
{evidence}
|
||||
|
||||
# Original user query:
|
||||
{query}
|
||||
|
||||
# Generated SQL:
|
||||
{generated_sql}
|
||||
|
||||
# SQL query results:
|
||||
{sql_result}
|
||||
|
||||
# Constraints:
|
||||
- 1. Evaluate the relevance and quality of the SQL query results to the original user query.
|
||||
- 2. If the SQL result is empty or relevant, answer the user's question directly from the information above; do not output redundant phrases such as "according to" or "based on".
|
||||
Respond in the following JSON format:
|
||||
{{
|
||||
"final_answer": "Return a full natural language answer"
|
||||
}}
|
||||
- 3. If the SQL result is not empty and not relevant, explain why the SQL query is not relevant.
|
||||
Respond in the following JSON format:
|
||||
{{
|
||||
"explanation": "explanation of why the SQL query is not relevant"
|
||||
}}
|
||||
""" # noqa:E501
|
||||
|
||||
|
||||
async def arun(
|
||||
query: str,
|
||||
llm: BaseChatModel,
|
||||
db_info: str,
|
||||
sql: str,
|
||||
res_rows: list[tuple[Any, ...]],
|
||||
evidence: str | None = None,
|
||||
) -> tuple[str, int, int, str | None]:
|
||||
answer = ""
|
||||
error = None
|
||||
input_tokens, output_tokens = 0, 0
|
||||
human_msg = prompt.format(
|
||||
db_info=db_info,
|
||||
query=query,
|
||||
generated_sql=sql,
|
||||
sql_result=res_rows,
|
||||
evidence=evidence or "",
|
||||
)
|
||||
try:
|
||||
content = ""
|
||||
async for chunk in llm.astream([HumanMessage(human_msg)]):
|
||||
msg = cast(AIMessage, chunk)
|
||||
content += cast(str, msg.content)
|
||||
input_tokens, output_tokens = await asyncio.to_thread(cal_tokens, msg)
|
||||
except Exception as e:
|
||||
raise ValueError(f"调用LLM失败:{e}")
|
||||
|
||||
try:
|
||||
result = await asyncio.to_thread(text2json, cast(str, content))
|
||||
if "final_answer" in result:
|
||||
answer = result["final_answer"] or ""
|
||||
else:
|
||||
error = result.get("explanation", "")
|
||||
except JSONDecodeError:
|
||||
error = "Incorrect json format"
|
||||
|
||||
return answer, input_tokens, output_tokens, error
|
||||
|
|
@ -0,0 +1,247 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import functools
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from decimal import Decimal
|
||||
from typing import Any, Callable
|
||||
|
||||
import json_repair as json
|
||||
from dateutil.parser import parse as date_parse
|
||||
from langchain_core.messages import AIMessage
|
||||
from sqlglot import exp, parse_one
|
||||
from sqlglot.optimizer.qualify import qualify
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
|
||||
JOIN_CHAR = "."
|
||||
|
||||
email_re = re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
|
||||
|
||||
|
||||
def text2json(text: str) -> dict[Any, Any]:
|
||||
res = json.loads(text)
|
||||
if isinstance(res, dict):
|
||||
return res
|
||||
return dict()
|
||||
|
||||
|
||||
def get_beijing_time(fmt: str = "%Y-%m-%d %H:%M:%S") -> str:
|
||||
utc_now = datetime.datetime.now(datetime.UTC)
|
||||
beijing_now = utc_now + datetime.timedelta(hours=8)
|
||||
formatted_time = beijing_now.strftime(fmt)
|
||||
return formatted_time
|
||||
|
||||
|
||||
def cal_tokens(msg: AIMessage) -> tuple[int, int]:
|
||||
if msg.usage_metadata:
|
||||
return msg.usage_metadata["input_tokens"], msg.usage_metadata["output_tokens"]
|
||||
elif msg.response_metadata.get("token_usage", {}).get("input_tokens"):
|
||||
token_usage = msg.response_metadata.get("token_usage", {})
|
||||
input_tokens = token_usage.get("input_tokens", 0)
|
||||
output_tokens = token_usage.get("output_tokens", 0)
|
||||
return input_tokens, output_tokens
|
||||
else:
|
||||
return 0, 0
|
||||
|
||||
|
||||
def convert_value2str(value: Any) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
elif isinstance(value, Decimal):
|
||||
return str(float(value))
|
||||
elif value is None:
|
||||
return ""
|
||||
else:
|
||||
return str(value)
|
||||
|
||||
|
||||
def is_valid_uuid(s: str) -> bool:
|
||||
try:
|
||||
uuid.UUID(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def is_number(value: Any) -> bool:
|
||||
try:
|
||||
float(value)
|
||||
return True
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
|
||||
def is_date(value: Any) -> bool:
|
||||
try:
|
||||
date_parse(value)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def is_email(value: Any) -> bool:
|
||||
if email_re.match(value):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def truncate_text(content: Any, *, max_length: int, suffix: str = "...") -> Any:
|
||||
if not isinstance(content, str) or max_length <= 0:
|
||||
return content
|
||||
|
||||
if len(content) <= max_length:
|
||||
return content
|
||||
|
||||
if max_length <= len(suffix):
|
||||
return content[:max_length]
|
||||
|
||||
# Make sure the truncated text stays within max_length and is not cut mid-word.
|
||||
return content[: max_length - len(suffix)].rsplit(" ", 1)[0] + suffix
|
||||
|
||||
|
||||
def truncate_text_by_byte(content: Any, *, max_length: int, suffix: str = "...", encoding: str = "utf-8") -> Any:
|
||||
if not isinstance(content, str) or max_length <= 0:
|
||||
return content
|
||||
|
||||
encoded = content.encode(encoding)
|
||||
if len(encoded) <= max_length:
|
||||
return content
|
||||
|
||||
suffix_bytes = suffix.encode(encoding)
|
||||
available_bytes = max_length - len(suffix_bytes)
|
||||
|
||||
if available_bytes <= 0:
|
||||
for i in range(max_length, 0, -1):
|
||||
try:
|
||||
return encoded[:i].decode(encoding)
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
return ""
|
||||
|
||||
# Find a cut-off point that still decodes safely
|
||||
for i in range(available_bytes, 0, -1):
|
||||
try:
|
||||
truncated_part = encoded[:i].decode(encoding)
|
||||
return truncated_part + suffix
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
|
||||
return suffix_bytes[:max_length].decode(encoding, errors="ignore")
|
||||
|
||||
|
||||
async def async_parallel_exec(func: Callable[[Any], Any], data: list[Any], concurrency: int | None = None) -> list[Any]:
|
||||
if not inspect.iscoroutinefunction(func):
|
||||
async_func = functools.partial(asyncio.to_thread, func) # type: ignore
|
||||
else:
|
||||
async_func = func # type: ignore
|
||||
|
||||
if len(data) == 0:
|
||||
return []
|
||||
|
||||
if concurrency is None:
|
||||
concurrency = len(data)
|
||||
semaphore = asyncio.Semaphore(concurrency)
|
||||
|
||||
async def worker(*args, **kwargs) -> Any:
|
||||
async with semaphore:
|
||||
return await async_func(*args, **kwargs) # type: ignore
|
||||
|
||||
t_list = []
|
||||
for i in range(len(data)):
|
||||
if isinstance(data[i], (list, tuple)):
|
||||
t_list.append(worker(*data[i]))
|
||||
elif isinstance(data[i], dict):
|
||||
t_list.append(worker(**data[i]))
|
||||
else:
|
||||
t_list.append(worker(data[i]))
|
||||
|
||||
return await asyncio.gather(*t_list)
|
||||
|
||||
|
||||
def parallel_exec(
|
||||
func: Callable[[Any], Any], data: list[dict[Any, Any] | tuple[Any] | list[Any]], concurrency: int | None = None
|
||||
) -> list[Any]:
|
||||
if len(data) == 0:
|
||||
return []
|
||||
|
||||
t_list = []
|
||||
if concurrency is None:
|
||||
concurrency = min(len(data), 32, (os.cpu_count() or 1) + 4)
|
||||
with ThreadPoolExecutor(max_workers=concurrency) as pool:
|
||||
for i in range(len(data)):
|
||||
if isinstance(data[i], (list, tuple)):
|
||||
t_list.append(pool.submit(func, *data[i]))
|
||||
elif isinstance(data[i], dict):
|
||||
t_list.append(pool.submit(func, **data[i])) # type: ignore
|
||||
else:
|
||||
t_list.append(pool.submit(func, data[i]))
|
||||
|
||||
return [t.result() for t in t_list]
|
||||
|
||||
|
||||
def exec_async_func(func: Callable[[Any], Any], *args, **kwargs) -> Any:
|
||||
if inspect.iscoroutinefunction(func):
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
return loop.run_until_complete(func(*args, **kwargs))
|
||||
else:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def parse_col_name_from_scope(
|
||||
scope: Scope, real_cols: set[str], cte_names: set[str], case_insensitive: bool = True
|
||||
) -> None:
|
||||
if case_insensitive:
|
||||
scope.sources = {k.lower(): v for k, v in scope.sources.items()}
|
||||
|
||||
for col in scope.columns:
|
||||
if col.table:
|
||||
tn = col.table
|
||||
if case_insensitive:
|
||||
tn = tn.lower()
|
||||
if tn in scope.sources:
|
||||
source = scope.sources[tn]
|
||||
if isinstance(source, exp.Table) and tn not in cte_names:
|
||||
src_tn = source.name
|
||||
col_name = f"{src_tn}{JOIN_CHAR}{col.name}"
|
||||
if case_insensitive:
|
||||
col_name = col_name.lower()
|
||||
real_cols.add(col_name)
|
||||
elif not col.table and scope.tables:
|
||||
tn = scope.tables[0].name
|
||||
if tn not in cte_names:
|
||||
col_name = f"{tn}{JOIN_CHAR}{col.name}"
|
||||
if case_insensitive:
|
||||
col_name = col_name.lower()
|
||||
real_cols.add(col_name)
|
||||
|
||||
|
||||
def parse_real_cols(sql: str, dialect="mysql", case_insensitive: bool = True) -> set[str]:
|
||||
try:
|
||||
parsed = parse_one(sql, dialect=dialect)
|
||||
parsed = qualify(parsed, dialect=dialect)
|
||||
cte_names = set()
|
||||
for cte in parsed.find_all(exp.CTE):
|
||||
cte_names.add(cte.alias)
|
||||
scopes = traverse_scope(parsed)
|
||||
real_cols: set[str] = set()
|
||||
for scope in scopes:
|
||||
parse_col_name_from_scope(scope, real_cols, cte_names, case_insensitive)
|
||||
return real_cols
|
||||
except Exception as e:
|
||||
print(f"Error when parsing, error: {e}")
|
||||
return set()
|
||||
|
||||
|
||||
def calculate_metrics(tp: float, fp: float, fn: float) -> tuple[float, float, float]:
|
||||
precision = tp / (tp + fp) if tp + fp > 0 else 0
|
||||
recall = tp / (tp + fn) if tp + fn > 0 else 0
|
||||
f1_score = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
|
||||
return precision, recall, f1_score
|
||||
|
|
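A quick check of parse_real_cols and calculate_metrics above (illustration only; the query and counts are made up).

```python
from dative.core.utils import calculate_metrics, parse_real_cols

cols = parse_real_cols(
    "SELECT u.name, o.total FROM users AS u JOIN orders AS o ON u.id = o.user_id",
    dialect="mysql",
)
print(sorted(cols))  # ['orders.total', 'orders.user_id', 'users.id', 'users.name']

print(calculate_metrics(tp=2, fp=1, fn=1))  # (0.666..., 0.666..., 0.666...)
```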
@ -0,0 +1,28 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from dative.api import v1
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
app.include_router(v1.router, prefix="/api/v1")
|
||||
|
||||
|
||||
@app.get("/healthz")
|
||||
async def healthz() -> str:
|
||||
return "ok"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=3000)
|
||||
|
|
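A quick smoke test of the app above with FastAPI's TestClient (which requires the httpx package); it assumes src/ is on the import path, as the test conftest below arranges.

```python
from fastapi.testclient import TestClient

from dative.main import app

client = TestClient(app)
response = client.get("/healthz")
assert response.status_code == 200
assert response.json() == "ok"
```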
@ -0,0 +1,29 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
# Add the src directory to sys.path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
|
||||
# Note: pytest fixtures do not mix with unittest.TestCase classes
|
||||
if os.name != "nt":
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def setup_and_teardown():
|
||||
# Setup logic
|
||||
|
||||
yield  # run the tests
|
||||
|
||||
# Teardown logic
|
||||
|
||||
|
||||
def wait_until(condition_function: Callable, timeout=10):
|
||||
start_time = time.time()
|
||||
while not condition_function():
|
||||
if timeout is not None and time.time() - start_time > timeout:
|
||||
raise TimeoutError("Condition not met within the specified timeout")
|
||||
time.sleep(0.01)  # brief sleep to reduce CPU usage
|
||||
|
|
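An illustrative use of the wait_until helper above inside a test; importing it from conftest is an assumption about how the helper would be shared, not something this diff establishes.

```python
import time

from conftest import wait_until  # assumption: conftest.py is importable during the test run


def test_background_condition_is_met():
    deadline = time.time() + 0.05
    wait_until(lambda: time.time() >= deadline, timeout=1)
    assert time.time() >= deadline
```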
@ -0,0 +1,130 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from dative.core.evaluation.schema_retrieval import calculate_retrieval_metrics as calculate_f1_score
|
||||
|
||||
|
||||
def test_calculate_f1_score_perfect_match():
|
||||
"""测试检索结果与gold SQL完全匹配的情况"""
|
||||
# 创建模拟的DBTable对象
|
||||
table1 = Mock()
|
||||
table1.name = "table1"
|
||||
table1.columns = {"a": None, "b": None} # columns.keys()返回['a', 'b']
|
||||
|
||||
table2 = Mock()
|
||||
table2.name = "table2"
|
||||
table2.columns = {"c": None}
|
||||
|
||||
retrieval_res = [table1, table2]
|
||||
gold_sql = "SELECT table1.a, table1.b, table2.c FROM table1 JOIN table2"
|
||||
|
||||
precision, recall, f1 = calculate_f1_score(retrieval_res, gold_sql)
|
||||
|
||||
# 完全匹配时,precision, recall, f1都应该是1.0
|
||||
assert precision == 1.0
|
||||
assert recall == 1.0
|
||||
assert f1 == 1.0
|
||||
|
||||
|
||||
def test_calculate_f1_score_partial_match():
|
||||
"""测试检索结果与gold SQL部分匹配的情况"""
|
||||
# 创建模拟的DBTable对象
|
||||
table1 = Mock()
|
||||
table1.name = "table1"
|
||||
table1.columns = {"a": None, "b": None}
|
||||
|
||||
retrieval_res = [table1] # 只检索到table1的列
|
||||
gold_sql = "SELECT table1.a, table1.b, table2.c FROM table1 JOIN table2" # gold SQL需要table1和table2的列
|
||||
|
||||
precision, recall, f1 = calculate_f1_score(retrieval_res, gold_sql)
|
||||
|
||||
# 精确度应该是1.0 (检索到的都正确),召回率应该是2/3 (只找到2个正确的,共3个)
|
||||
assert precision == 1.0
|
||||
assert recall == pytest.approx(2 / 3)
|
||||
assert f1 == pytest.approx(2 * (1.0 * 2 / 3) / (1.0 + 2 / 3))
|
||||
|
||||
|
||||
def test_calculate_f1_score_no_match():
|
||||
"""测试检索结果与gold SQL完全不匹配的情况"""
|
||||
# 创建模拟的DBTable对象
|
||||
table1 = Mock()
|
||||
table1.name = "table3" # 与gold SQL中的表名不同
|
||||
table1.columns = {"d": None, "e": None}
|
||||
|
||||
retrieval_res = [table1]
|
||||
gold_sql = "SELECT table1.a, table1.b, table2.c FROM table1 JOIN table2"
|
||||
|
||||
precision, recall, f1 = calculate_f1_score(retrieval_res, gold_sql)
|
||||
|
||||
# 没有匹配时,precision, recall, f1都应该是0.0
|
||||
assert precision == 0.0
|
||||
assert recall == 0.0
|
||||
assert f1 == 0.0
|
||||
|
||||
|
||||
def test_calculate_f1_score_case_insensitive():
|
||||
"""测试大小写不敏感的情况"""
|
||||
table1 = Mock()
|
||||
table1.name = "Table1" # 大小写与gold SQL不同
|
||||
table1.columns = {"A": None, "B": None}
|
||||
|
||||
retrieval_res = [table1]
|
||||
gold_sql = "SELECT table1.a, table1.b FROM table1"
|
||||
|
||||
# 默认大小写不敏感
|
||||
precision, recall, f1 = calculate_f1_score(retrieval_res, gold_sql)
|
||||
|
||||
# 应该匹配,因为大小写不敏感
|
||||
assert precision == 1.0
|
||||
assert recall == 1.0
|
||||
assert f1 == 1.0
|
||||
|
||||
|
||||
def test_calculate_f1_score_case_sensitive():
|
||||
"""测试大小写敏感的情况"""
|
||||
table1 = Mock()
|
||||
table1.name = "Table1" # 大小写与gold SQL不同
|
||||
table1.columns = {"A": None, "B": None}
|
||||
|
||||
retrieval_res = [table1]
|
||||
gold_sql = "SELECT table1.a, table1.b FROM table1"
|
||||
|
||||
# 设置大小写敏感
|
||||
precision, recall, f1 = calculate_f1_score(retrieval_res, gold_sql, case_insensitive=False)
|
||||
|
||||
# 应该不匹配,因为大小写敏感
|
||||
assert precision == 0.0
|
||||
assert recall == 0.0
|
||||
assert f1 == 0.0
|
||||
|
||||
|
||||
def test_calculate_f1_score_empty_inputs():
|
||||
"""测试空输入的情况"""
|
||||
# 空的检索结果和简单的SQL
|
||||
retrieval_res = []
|
||||
gold_sql = "SELECT a FROM table1"
|
||||
|
||||
precision, recall, f1 = calculate_f1_score(retrieval_res, gold_sql)
|
||||
|
||||
# 检索结果为空时,precision和f1应该是0,recall也应该是0
|
||||
assert precision == 0.0
|
||||
assert recall == 0.0
|
||||
assert f1 == 0.0
|
||||
|
||||
# 空SQL的情况
|
||||
table1 = Mock()
|
||||
table1.name = "table1"
|
||||
table1.columns = {"a": None}
|
||||
|
||||
retrieval_res = [table1]
|
||||
gold_sql = ""
|
||||
|
||||
precision, recall, f1 = calculate_f1_score(retrieval_res, gold_sql)
|
||||
|
||||
# SQL解析失败时应该返回0分
|
||||
assert precision == 0.0
|
||||
assert recall == 0.0
|
||||
assert f1 == 0.0
|
||||
|
|
@ -0,0 +1,142 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import pytest
|
||||
|
||||
from dative.core.evaluation.sql_generation import calculate_ex, calculate_f1_score, calculate_row_match
|
||||
|
||||
|
||||
class TestCalculateEx:
|
||||
"""测试 calculate_ex 函数"""
|
||||
|
||||
def test_exact_match(self):
|
||||
"""测试完全匹配的情况"""
|
||||
predicted = [(1, "A"), (2, "B")]
|
||||
ground_truth = [(1, "A"), (2, "B")]
|
||||
assert calculate_ex(predicted, ground_truth) == 1
|
||||
|
||||
def test_exact_match_different_order(self):
|
||||
"""测试顺序不同但内容相同的匹配"""
|
||||
predicted = [(2, "B"), (1, "A")]
|
||||
ground_truth = [(1, "A"), (2, "B")]
|
||||
assert calculate_ex(predicted, ground_truth) == 1
|
||||
|
||||
def test_no_match(self):
|
||||
"""测试完全不匹配的情况"""
|
||||
predicted = [(1, "A"), (2, "B")]
|
||||
ground_truth = [(3, "C"), (4, "D")]
|
||||
assert calculate_ex(predicted, ground_truth) == 0
|
||||
|
||||
def test_partial_match(self):
|
||||
"""测试部分匹配的情况"""
|
||||
predicted = [(1, "A"), (2, "B")]
|
||||
ground_truth = [(1, "A"), (3, "C")]
|
||||
assert calculate_ex(predicted, ground_truth) == 0
|
||||
|
||||
def test_empty_inputs(self):
|
||||
"""测试空输入的情况"""
|
||||
assert calculate_ex([], []) == 1
|
||||
assert calculate_ex([(1, "A")], []) == 0
|
||||
assert calculate_ex([], [(1, "A")]) == 0
|
||||
|
||||
|
||||
class TestCalculateRowMatch:
|
||||
"""测试 calculate_row_match 函数"""
|
||||
|
||||
def test_perfect_match(self):
|
||||
"""测试行完全匹配的情况"""
|
||||
predicted = (1, "A", 3.14)
|
||||
ground_truth = (1, "A", 3.14)
|
||||
match_pct, pred_only_pct, truth_only_pct = calculate_row_match(predicted, ground_truth)
|
||||
assert match_pct == 1.0
|
||||
assert pred_only_pct == 0.0
|
||||
assert truth_only_pct == 0.0
|
||||
|
||||
def test_partial_match(self):
|
||||
"""测试行部分匹配的情况"""
|
||||
predicted = (1, "B", 3.14)
|
||||
ground_truth = (1, "A", 3.14)
|
||||
match_pct, pred_only_pct, truth_only_pct = calculate_row_match(predicted, ground_truth)
|
||||
# 2个值匹配(1和3.14),共3列
|
||||
assert match_pct == pytest.approx(2 / 3)
|
||||
assert pred_only_pct == pytest.approx(1 / 3)
|
||||
assert truth_only_pct == pytest.approx(1 / 3)
|
||||
|
||||
def test_no_match(self):
|
||||
"""测试行完全不匹配的情况"""
|
||||
predicted = (2, "B", 2.71)
|
||||
ground_truth = (1, "A", 3.14)
|
||||
match_pct, pred_only_pct, truth_only_pct = calculate_row_match(predicted, ground_truth)
|
||||
assert match_pct == 0.0
|
||||
assert pred_only_pct == 1.0
|
||||
assert truth_only_pct == 1.0
|
||||
|
||||
def test_duplicate_values(self):
|
||||
"""测试包含重复值的情况"""
|
||||
predicted = (1, 1, "A")
|
||||
ground_truth = (1, "A", "A")
|
||||
match_pct, pred_only_pct, truth_only_pct = calculate_row_match(predicted, ground_truth)
|
||||
# 1和A都存在于ground_truth中
|
||||
assert match_pct == 1.0
|
||||
assert pred_only_pct == 0.0
|
||||
assert truth_only_pct == 0.0
|
||||
|
||||
|
||||
class TestCalculateF1Score:
|
||||
"""测试 calculate_f1_score 函数"""
|
||||
|
||||
def test_perfect_match(self):
|
||||
"""测试完全匹配的情况"""
|
||||
predicted = [(1, "A"), (2, "B")]
|
||||
ground_truth = [(1, "A"), (2, "B")]
|
||||
f1_score = calculate_f1_score(predicted, ground_truth)
|
||||
assert f1_score == 1.0
|
||||
|
||||
def test_perfect_match_different_order(self):
|
||||
"""测试顺序不同但内容相同的匹配"""
|
||||
predicted = [(2, "B"), (1, "A")]
|
||||
ground_truth = [(1, "A"), (2, "B")]
|
||||
f1_score = calculate_f1_score(predicted, ground_truth)
|
||||
assert f1_score == 1.0
|
||||
|
||||
def test_empty_results(self):
|
||||
"""测试都为空的情况"""
|
||||
f1_score = calculate_f1_score([], [])
|
||||
assert f1_score == 1.0
|
||||
|
||||
def test_one_empty_result(self):
|
||||
"""测试一个为空的情况"""
|
||||
predicted = [(1, "A")]
|
||||
ground_truth = []
|
||||
f1_score = calculate_f1_score(predicted, ground_truth)
|
||||
assert f1_score == 0.0
|
||||
|
||||
predicted = []
|
||||
ground_truth = [(1, "A")]
|
||||
f1_score = calculate_f1_score(predicted, ground_truth)
|
||||
assert f1_score == 0.0
|
||||
|
||||
def test_partial_match(self):
|
||||
"""测试部分匹配的情况"""
|
||||
predicted = [(1, "A", 10), (3, "C", 30)]
|
||||
ground_truth = [(1, "A", 10), (2, "B", 20)]
|
||||
|
||||
# 第一行完全匹配(1.0),第二行部分匹配(部分值匹配)
|
||||
f1_score = calculate_f1_score(predicted, ground_truth)
|
||||
# 需要确保返回的是一个合理的f1分数(0-1之间)
|
||||
assert 0.0 <= f1_score <= 1.0
|
||||
|
||||
def test_duplicate_rows(self):
|
||||
"""测试重复行的情况"""
|
||||
predicted = [(1, "A"), (1, "A"), (2, "B")]
|
||||
ground_truth = [(1, "A"), (2, "B"), (2, "B")]
|
||||
f1_score = calculate_f1_score(predicted, ground_truth)
|
||||
# 重复项应该被去除,结果应该与去重后相同
|
||||
assert 0.0 <= f1_score <= 1.0
|
||||
|
||||
def test_different_row_lengths(self):
|
||||
"""测试行长度不同的情况"""
|
||||
predicted = [(1, "A", 10, "extra")]
|
||||
ground_truth = [(1, "A", 10)]
|
||||
# 这种情况可能产生意外结果,但函数应该能处理
|
||||
f1_score = calculate_f1_score(predicted, ground_truth)
|
||||
assert 0.0 <= f1_score <= 1.0
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
import pytest
|
||||
from sqlglot import exp
|
||||
|
||||
from dative.core.sql_inspection.check import SQLCheck
|
||||
|
||||
|
||||
class TestSQLCheck:
|
||||
@pytest.fixture
|
||||
def sql_check(self):
|
||||
return SQLCheck(dialect="mysql")
|
||||
|
||||
def test_is_query_with_select_statement(self, sql_check):
|
||||
"""测试SELECT语句是否被识别为查询"""
|
||||
sql = "SELECT * FROM users"
|
||||
assert sql_check.is_query(sql) is True
|
||||
|
||||
def test_is_query_with_insert_statement(self, sql_check):
|
||||
"""测试INSERT语句是否不被识别为查询"""
|
||||
sql = "INSERT INTO users (name) VALUES ('Alice')"
|
||||
assert sql_check.is_query(sql) is False
|
||||
|
||||
def test_is_query_with_exp_expression(self, sql_check):
|
||||
"""测试直接传入Expression对象的情况"""
|
||||
expression = exp.select("*").from_("users")
|
||||
assert sql_check.is_query(expression) is True
|
||||
|
||||
def test_syntax_valid_with_correct_sql(self, sql_check):
|
||||
"""测试语法正确的SQL"""
|
||||
sql = "SELECT id, name FROM users WHERE age > 18"
|
||||
result = sql_check.syntax_valid(sql)
|
||||
assert isinstance(result, exp.Expression)
|
||||
|
||||
def test_syntax_valid_with_incorrect_sql(self, sql_check):
|
||||
"""测试语法错误的SQL应抛出异常"""
|
||||
sql = "SELECT FROM WHERE"
|
||||
with pytest.raises(Exception):
|
||||
sql_check.syntax_valid(sql)
|
||||
|
||||
def test_mysql_dialect(self):
|
||||
"""测试MySQL方言"""
|
||||
sql_check = SQLCheck(dialect="mysql")
|
||||
sql = "SELECT * FROM users LIMIT 10"
|
||||
assert sql_check.is_query(sql) is True
|
||||
|
||||
def test_postgres_dialect(self):
|
||||
"""测试PostgreSQL方言"""
|
||||
sql_check = SQLCheck(dialect="postgres")
|
||||
sql = "SELECT * FROM users LIMIT 10"
|
||||
assert sql_check.is_query(sql) is True
|
||||
|
|
@ -0,0 +1,142 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import pytest
|
||||
from sqlglot import exp, parse_one
|
||||
|
||||
from dative.core.sql_inspection.optimization import SQLOptimization
|
||||
|
||||
|
||||
class TestSQLOptimization:
|
||||
"""SQLOptimization 类的单元测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def mysql_optimizer(self):
|
||||
"""MySQL 5.7 版本的优化器"""
|
||||
return SQLOptimization(dialect="mysql", db_major_version=5)
|
||||
|
||||
@pytest.fixture
|
||||
def mysql8_optimizer(self):
|
||||
"""MySQL 8.0 版本的优化器"""
|
||||
return SQLOptimization(dialect="mysql", db_major_version=8)
|
||||
|
||||
@pytest.fixture
|
||||
def postgres_optimizer(self):
|
||||
"""PostgreSQL 优化器"""
|
||||
return SQLOptimization(dialect="postgres", db_major_version=13)
|
||||
|
||||
def test_cte_to_subquery(self):
|
||||
"""测试 CTE 转换为子查询的功能"""
|
||||
sql = """
|
||||
WITH cte1 AS (SELECT id, name FROM users WHERE age > 18)
|
||||
SELECT * \
|
||||
FROM cte1 \
|
||||
WHERE name LIKE 'A%' \
|
||||
"""
|
||||
expression = parse_one(sql)
|
||||
result = SQLOptimization.cte_to_subquery(expression)
|
||||
|
||||
# 验证 CTE 已被替换为子查询
|
||||
assert "WITH" not in result.sql()
|
||||
assert "users" in result.sql()
|
||||
assert "age > 18" in result.sql()
|
||||
|
||||
def test_optimize_in_limit_subquery(self):
|
||||
"""测试优化带 LIMIT 的 IN 子查询"""
|
||||
sql = """
|
||||
SELECT * \
|
||||
FROM orders
|
||||
WHERE user_id IN (SELECT id FROM users LIMIT 10) \
|
||||
"""
|
||||
expression = parse_one(sql)
|
||||
result = SQLOptimization.optimize_in_limit_subquery(expression)
|
||||
|
||||
# 验证 LIMIT 子查询已被重写
|
||||
subquery = result.find(exp.In).args["query"]
|
||||
assert subquery is not None
|
||||
# 确保新的子查询结构正确
|
||||
assert isinstance(subquery, exp.Subquery)
|
||||
|
||||
def test_fix_missing_group_by_when_agg_func_case1(self):
|
||||
"""测试修复 GROUP BY 缺失字段 - 情况1"""
|
||||
sql = "SELECT a, b FROM x GROUP BY a"
|
||||
expression = parse_one(sql)
|
||||
result = SQLOptimization.fix_missing_group_by_when_agg_func(expression)
|
||||
|
||||
# 应该添加缺失的 GROUP BY 字段 b
|
||||
group_clause = result.args.get("group")
|
||||
assert group_clause is not None
|
||||
group_expressions = group_clause.expressions
|
||||
assert len(group_expressions) == 2
|
||||
|
||||
def test_fix_missing_group_by_when_agg_func_case2(self):
|
||||
"""测试修复 GROUP BY 缺失字段 - 情况2"""
|
||||
sql = "SELECT a, COUNT(b) FROM x"
|
||||
expression = parse_one(sql)
|
||||
result = SQLOptimization.fix_missing_group_by_when_agg_func(expression)
|
||||
|
||||
# 应该自动添加 GROUP BY a
|
||||
group_clause = result.args.get("group")
|
||||
assert group_clause is not None
|
||||
|
||||
def test_fix_missing_group_by_when_agg_func_case3(self):
|
||||
"""测试修复 GROUP BY 缺失字段 - 情况3"""
|
||||
sql = "SELECT a FROM x ORDER BY MAX(b)"
|
||||
expression = parse_one(sql)
|
||||
result = SQLOptimization.fix_missing_group_by_when_agg_func(expression)
|
||||
|
||||
# 应该添加 GROUP BY a
|
||||
group_clause = result.args.get("group")
|
||||
assert group_clause is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arun_with_string_sql(self, mysql_optimizer):
|
||||
"""测试 arun 方法处理字符串 SQL"""
|
||||
sql = "SELECT id, name FROM users"
|
||||
result = mysql_optimizer.arun(sql, result_num_limit=100)
|
||||
|
||||
assert "LIMIT 100" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arun_with_expression(self, mysql_optimizer):
|
||||
"""测试 arun 方法处理表达式对象"""
|
||||
expression = parse_one("SELECT id, name FROM users")
|
||||
result = mysql_optimizer.arun(expression, result_num_limit=50)
|
||||
|
||||
assert "LIMIT 50" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mysql_cte_handling_old_version(self, mysql_optimizer):
|
||||
"""测试 MySQL 低版本 CTE 处理(应转换为子查询)"""
|
||||
sql = """
|
||||
WITH cte AS (SELECT id FROM users)
|
||||
SELECT *
|
||||
FROM cte \
|
||||
"""
|
||||
result = mysql_optimizer.arun(sql)
|
||||
|
||||
# 在 MySQL 5.x 中,CTE 应被转换为子查询
|
||||
assert "WITH" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mysql_cte_handling_new_version(self, mysql8_optimizer):
|
||||
"""测试 MySQL 8.0+ CTE 处理(应保留 CTE)"""
|
||||
sql = """
|
||||
WITH cte AS (SELECT id FROM users)
|
||||
SELECT *
|
||||
FROM cte
|
||||
join b
|
||||
on cte.id = b.id \
|
||||
"""
|
||||
result = mysql8_optimizer.arun(sql)
|
||||
# 在 MySQL 8.0+ 中,CTE 应被保留
|
||||
assert "WITH" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_optimization_with_schema(self, postgres_optimizer):
|
||||
"""测试带模式信息的优化"""
|
||||
sql = "SELECT u.name, COUNT(o.id) FROM users u JOIN orders o ON u.id = o.user_id"
|
||||
schema = {"users": {"id": "INT", "name": "VARCHAR"}, "orders": {"id": "INT", "user_id": "INT"}}
|
||||
result = postgres_optimizer.arun(sql, schema_type=schema)
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
|
|
@ -0,0 +1,271 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
import datetime
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from dative.core.utils import (
|
||||
async_parallel_exec,
|
||||
cal_tokens,
|
||||
calculate_metrics as calculate_f1,
|
||||
exec_async_func,
|
||||
get_beijing_time,
|
||||
is_date,
|
||||
is_email,
|
||||
is_number,
|
||||
is_valid_uuid,
|
||||
parallel_exec,
|
||||
parse_real_cols,
|
||||
text2json,
|
||||
truncate_text,
|
||||
truncate_text_by_byte,
|
||||
)
|
||||
from dative.core.utils import convert_value2str
|
||||
|
||||
|
||||
def test_text2json_valid_json():
|
||||
text = '{"name": "Alice", "age": 30}'
|
||||
expected = {"name": "Alice", "age": 30}
|
||||
assert text2json(text) == expected
|
||||
|
||||
|
||||
def test_text2json_invalid_json():
|
||||
text = "invalid json"
|
||||
expected = {}
|
||||
assert text2json(text) == expected
|
||||
|
||||
|
||||
def test_text2json_non_dict_result():
|
||||
text = '["not", "a", "dict"]'
|
||||
expected = {}
|
||||
assert text2json(text) == expected
|
||||
|
||||
|
||||
def test_case_001_usage_metadata_exists():
|
||||
msg = MagicMock()
|
||||
msg.usage_metadata = {"input_tokens": 10, "output_tokens": 20}
|
||||
msg.response_metadata = {}
|
||||
result = cal_tokens(msg)
|
||||
|
||||
assert result == (10, 20)
|
||||
|
||||
|
||||
def test_case_002_token_usage_in_response_metadata():
|
||||
msg = MagicMock()
|
||||
msg.usage_metadata = None
|
||||
msg.response_metadata = {"token_usage": {"input_tokens": 5, "output_tokens": 15}}
|
||||
result = cal_tokens(msg)
|
||||
assert result == (5, 15)
|
||||
|
||||
|
||||
def test_get_beijing_time_default_format():
|
||||
result = get_beijing_time()
|
||||
assert datetime.datetime.strptime(result, "%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def test_get_beijing_time_custom_format():
|
||||
fmt = "%Y/%m/%d"
|
||||
result = get_beijing_time(fmt)
|
||||
assert datetime.datetime.strptime(result, fmt)
|
||||
|
||||
|
||||
def test_convert_value2str_string():
|
||||
assert convert_value2str("hello") == "hello"
|
||||
|
||||
|
||||
def test_convert_value2str_decimal():
|
||||
assert convert_value2str(Decimal("10.5")) == "10.5"
|
||||
|
||||
|
||||
def test_convert_value2str_none():
|
||||
assert convert_value2str(None) == ""
|
||||
|
||||
|
||||
def test_convert_value2str_other_types():
|
||||
assert convert_value2str(123) == "123"
|
||||
assert convert_value2str(45.67) == "45.67"
|
||||
assert convert_value2str(True) == "True"
|
||||
|
||||
|
||||
def test_is_valid_uuid():
|
||||
valid_uuid = str(uuid.uuid4())
|
||||
assert is_valid_uuid(valid_uuid) is True
|
||||
assert is_valid_uuid("invalid-uuid") is False
|
||||
|
||||
|
||||
def test_is_number():
|
||||
assert is_number("123") is True
|
||||
assert is_number("123.45") is True
|
||||
assert is_number(123) is True
|
||||
assert is_number("not_a_number") is False
|
||||
assert is_number(None) is False
|
||||
|
||||
|
||||
def test_is_date():
|
||||
assert is_date("2023-01-01") is True
|
||||
assert is_date("Jan 1, 2023") is True
|
||||
assert is_date("invalid-date") is False
|
||||
|
||||
|
||||
def test_is_email():
|
||||
assert is_email("test@example.com") is True
|
||||
assert is_email("invalid-email") is False
|
||||
|
||||
|
||||
def test_truncate_text():
|
||||
text = "This is a long text for testing"
|
||||
assert truncate_text(text, max_length=10) == "This..."
|
||||
assert truncate_text(text, max_length=100) == text
|
||||
assert truncate_text(text, max_length=3) == "Thi"
|
||||
assert truncate_text(123, max_length=10) == 123
|
||||
|
||||
|
||||
def test_truncate_text_by_byte():
|
||||
text = "这是一个用于测试的长文本"
|
||||
result = truncate_text_by_byte(text, max_length=20)
|
||||
assert isinstance(result, str)
|
||||
assert len(result.encode("utf-8")) <= 20
|
||||
|
||||
# 测试英文文本
|
||||
english_text = "This is English text"
|
||||
result = truncate_text_by_byte(english_text, max_length=15)
|
||||
assert len(result.encode("utf-8")) <= 15
|
||||
|
||||
|
||||
def sample_function(x):
|
||||
return x * 2
|
||||
|
||||
|
||||
async def async_sample_function(x):
|
||||
await asyncio.sleep(0.01) # 模拟异步操作
|
||||
return x * 2
|
||||
|
||||
|
||||
def test_parallel_exec():
|
||||
data = [1, 2, 3, 4, 5]
|
||||
result = parallel_exec(sample_function, data)
|
||||
assert result == [2, 4, 6, 8, 10]
|
||||
|
||||
|
||||
def test_parallel_exec_with_kwargs():
|
||||
def func_with_kwargs(a, b):
|
||||
return a + b
|
||||
|
||||
data = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
|
||||
result = parallel_exec(func_with_kwargs, data)
|
||||
assert result == [3, 7]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_parallel_exec():
|
||||
data = [1, 2, 3, 4, 5]
|
||||
result = await async_parallel_exec(async_sample_function, data)
|
||||
assert result == [2, 4, 6, 8, 10]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_parallel_exec_with_sync_function():
|
||||
data = [1, 2, 3, 4, 5]
|
||||
result = await async_parallel_exec(sample_function, data)
|
||||
assert result == [2, 4, 6, 8, 10]
|
||||
|
||||
|
||||
async def async_add(a, b):
|
||||
await asyncio.sleep(0.1) # 模拟异步操作
|
||||
return a + b
|
||||
|
||||
|
||||
def test_exec_async_func():
|
||||
result = exec_async_func(async_add, 1, 2)
|
||||
assert result == 3
|
||||
|
||||
|
||||
def test_exec_async_func_with_kwargs():
|
||||
result = exec_async_func(async_add, a=1, b=2)
|
||||
assert result == 3
|
||||
|
||||
|
||||
def test_exec_async_func_with_mixed_args():
|
||||
result = exec_async_func(async_add, 1, b=2)
|
||||
assert result == 3
|
||||
|
||||
|
||||
def test_calculate_f1_normal_case():
|
||||
"""测试正常情况下的F1计算"""
|
||||
precision, recall, f1_score = calculate_f1(10, 5, 3)
|
||||
assert precision == 10 / 15 # 0.6667
|
||||
assert recall == 10 / 13 # 0.7692
|
||||
expected_f1 = 2 * precision * recall / (precision + recall)
|
||||
assert f1_score == expected_f1
|
||||
|
||||
|
||||
def test_calculate_f1_zero_precision_and_recall():
|
||||
"""测试精确度和召回率都为0的情况"""
|
||||
precision, recall, f1_score = calculate_f1(0, 0, 0)
|
||||
assert precision == 0
|
||||
assert recall == 0
|
||||
assert f1_score == 0
|
||||
|
||||
|
||||
def test_calculate_f1_zero_precision():
|
||||
"""测试精确度为0的情况"""
|
||||
precision, recall, f1_score = calculate_f1(0, 5, 5)
|
||||
assert precision == 0
|
||||
assert recall == 0
|
||||
assert f1_score == 0
|
||||
|
||||
|
||||
def test_calculate_f1_zero_recall():
|
||||
"""测试召回率为0的情况"""
|
||||
precision, recall, f1_score = calculate_f1(0, 0, 5)
|
||||
assert precision == 0
|
||||
assert recall == 0
|
||||
assert f1_score == 0
|
||||
|
||||
|
||||
def test_parse_real_cols_simple_select():
|
||||
"""测试简单SELECT查询"""
|
||||
sql = "SELECT a, b FROM table1"
|
||||
result = parse_real_cols(sql)
|
||||
expected = {"table1.a", "table1.b"}
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_parse_real_cols_with_join():
|
||||
"""测试包含JOIN的查询"""
|
||||
sql = "SELECT t1.a, t2.b FROM table1 t1 JOIN table2 t2 ON t1.id = t2.id"
|
||||
result = parse_real_cols(sql)
|
||||
expected = {"table1.a", "table1.id", "table2.b", "table2.id"}
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_parse_real_cols_with_cte():
|
||||
"""测试包含CTE的查询"""
|
||||
sql = """
|
||||
WITH cte1 AS (SELECT a FROM table1)
|
||||
SELECT cte1.a, table2.b
|
||||
FROM cte1
|
||||
JOIN table2 ON cte1.a = table2.a \
|
||||
"""
|
||||
result = parse_real_cols(sql)
|
||||
expected = {"table1.a", "table2.a", "table2.b"}
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_parse_real_cols_case_insensitive():
|
||||
"""测试大小写不敏感情况"""
|
||||
sql = "SELECT A, b FROM Table1"
|
||||
result = parse_real_cols(sql, case_insensitive=True)
|
||||
expected = {"table1.a", "table1.b"}
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_parse_real_cols_case_sensitive():
|
||||
"""测试大小写敏感情况"""
|
||||
sql = "SELECT A, b FROM Table1"
|
||||
result = parse_real_cols(sql, case_insensitive=False)
|
||||
expected = {"Table1.A", "Table1.b"}
|
||||
assert result == expected
|
||||