sxf-xiongtao 2025-12-22 19:06:20 +08:00 committed by GitHub
commit 25bbc13506
43 changed files with 5528 additions and 0 deletions

14
plugins/dative/.gitignore vendored Normal file

@ -0,0 +1,14 @@
__pycache__
.benchmarks
.idea
.mypy_cache
.pytest_cache
.ropeproject
.ruff_cache
.venv
.vscode
.zed
datasets/bird_minidev
.coverage
.DS_Store

34
plugins/dative/Dockerfile Normal file

@ -0,0 +1,34 @@
FROM astral/uv:0.8-python3.11-bookworm-slim AS builder
RUN apt-get update && apt-get install -y --no-install-recommends gcc libc-dev
ENV UV_COMPILE_BYTECODE=1 UV_NO_INSTALLER_METADATA=1 UV_LINK_MODE=copy UV_PYTHON_DOWNLOADS=0
# Install dependencies
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=uv.lock,target=uv.lock \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=scripts/install_duckdb_extensions.py,target=install_duckdb_extensions.py \
uv sync --frozen --no-install-project --no-dev \
&& uv run python install_duckdb_extensions.py
# clean the cache
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=uv.lock,target=uv.lock \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
uv sync --frozen --no-install-project --no-dev
FROM python:3.11-slim-bookworm AS runtime
COPY --from=builder --chown=dative:dative /.venv /.venv
COPY --from=builder --chown=dative:dative /root/.duckdb /root/.duckdb
COPY src/dative /dative
COPY docker/entrypoint.sh /
ENV PATH="/.venv/bin:$PATH"
EXPOSE 3000
ENTRYPOINT ["/entrypoint.sh"]

44
plugins/dative/Justfile Normal file

@ -0,0 +1,44 @@
# Install project dependencies
install:
uv sync
U:
uv sync --upgrade
up package:
uv sync --upgrade-package {{ package }}
# Format the project code
format:
uv run ruff format src tests/unit
# Run the project lint checks
check:
uv run ruff check --fix src tests/unit
sqlfmt: install
uv run scripts/sqlfmt.py src/dative/core/data_source/metadata_sql
# Type-check the project
type:
uv run mypy src
# Run unit tests
test:
uv run pytest tests/unit
# Run functional tests
test_function:
uv run pytest tests/function
# Run integration tests
test_integration:
uv run pytest tests/integration
ci: install format check type test
dev: install
uv run fastapi dev src/dative/main.py
run: install
uv run fastapi run src/dative/main.py

32
plugins/dative/README.md Normal file

@ -0,0 +1,32 @@
Turn a natural-language question into a SQL query and answer it in natural language.
## Features
- Supports MySQL, PostgreSQL, SQLite, and other databases
- Uses DuckDB's local data-management capabilities to query structured data (xlsx, xls, csv, xlsm, xlsb, etc.) stored locally or in S3
- Supports basic SQL syntax checking and optimization
- Supports querying a single database; cross-database queries are not supported
- Only SQL queries are supported; UPDATE, DELETE, INSERT, and other write statements are not
## Local Development
1. The project is managed with [uv](https://github.com/astral-sh/uv); install it with pip:
```bash
pip install uv
```
2. Install the dependencies:
```bash
uv sync
```
3. Install the DuckDB extensions:
```bash
uv run python scripts/install_duckdb_extensions.py
```
4. Start the service:
```bash
uv run fastapi run src/dative/main.py
```
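
Once the service is up, you can smoke-test it over HTTP. A minimal sketch using `httpx` (already a project dependency); the base URL and the `/api/v1` mount prefix are assumptions, since `main.py` is not shown in this diff, so adjust them to your deployment:

```python
# Hypothetical smoke test; base URL, route prefix, and the SQLite path are
# assumptions/placeholders, not taken from this repository.
import httpx

BASE_URL = "http://127.0.0.1:8000"  # default port of `fastapi run`

source_config = {"type": "sqlite", "db_path": "datasets/example.db"}

resp = httpx.post(f"{BASE_URL}/api/v1/data_source/conn_test", json=source_config)
print(resp.status_code, resp.text)  # expect HTTP 200 and "ok" on success
```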


@ -0,0 +1,51 @@
#!/bin/bash
set -e
# Function to check if environment is development (case-insensitive)
is_development() {
local env
env=$(echo "${DATIVE_ENVIRONMENT}" | tr '[:upper:]' '[:lower:]')
[[ "$env" == "development" ]]
}
start_arq() {
echo "Starting ARQ worker..."
# Use exec to replace the shell with arq_worker if running in single process mode
if ! is_development; then
exec arq_worker
else
arq_worker &
ARQ_PID=$!
echo "ARQ worker started with PID: $ARQ_PID"
fi
}
start_uvicorn() {
echo "Starting Uvicorn server..."
# Use PORT environment variable, default to 3000 if not set
local port="${PORT:-3000}"
# Use exec to replace the shell with uvicorn if running in single process mode
if ! is_development; then
exec uvicorn dative.main:app --host '0.0.0.0' --port "$port"
else
uvicorn dative.main:app --host '0.0.0.0' --port "$port" &
UVICORN_PID=$!
echo "Uvicorn server started with PID: $UVICORN_PID"
fi
}
# Start the appropriate service
if [[ "${MODE}" == "worker" ]]; then
start_arq
else
start_uvicorn
fi
if is_development; then
sleep infinity
else
echo "Signal handling disabled. Process will replace shell (exec)."
# In this case, start_arq or start_uvicorn would have used exec
# So we shouldn't reach here unless something went wrong
fi


@ -0,0 +1,110 @@
[project]
name = "dative"
version = "0.1.0"
description = "Query data from database by natural language"
requires-python = "~=3.11.0"
dependencies = [
"aiobotocore>=2.24.2",
"aiosqlite>=0.21.0",
"asyncpg>=0.30.0",
"cryptography>=45.0.6",
"duckdb>=1.3.2",
"fastapi<1.0",
"fastexcel>=0.15.1",
"greenlet>=3.2.4",
"httpx[socks]>=0.28.1",
"json-repair>=0.50.1",
"langchain-core<1.0",
"langchain-openai<1.0",
"mysql-connector-python>=9.4.0",
"orjson>=3.11.2",
"pyarrow>=21.0.0",
"pydantic-settings>=2.10.1",
"python-dateutil>=2.9.0.post0",
"python-dotenv>=1.1.1",
"pytz>=2025.2",
"sqlglot[rs]==27.12.0",
"uvicorn[standard]<1.0",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.sdist]
only-include = ["src/"]
[tool.hatch.build.targets.wheel]
packages = ["src/"]
[tool.uv]
index-url = "https://pypi.tuna.tsinghua.edu.cn/simple"
extra-index-url = ["https://mirrors.aliyun.com/pypi/simple"]
[dependency-groups]
dev = [
"hatch>=1.14.1",
"ruff>=0.12.10",
"pytest-async>=0.1.1",
"pytest-asyncio>=0.26.0",
"pytest-cov>=6.3.0",
"pytest-mock>=3.14.0",
"pytest>=8.3.3",
"pytest-benchmark>=4.0.0",
"fastapi-cli>=0.0.8",
"mypy>=1.17.1",
"types-python-dateutil>=2.9.0.20250822",
"pyarrow-stubs>=20.0.0.20250825",
"asyncpg-stubs>=0.30.2",
"boto3-stubs>=1.40.25",
"types-aiobotocore[essential]>=2.24.2",
]
[tool.uv.sources]
# checklist
[tool.ruff]
line-length = 120
target-version = "py311"
preview = true
[tool.ruff.lint]
select = [
"E", # pycodestyle
"F", # Pyflakes
"W", # Warning
"N", # PEP8 Naming
"I", # isort
"FAST", # FastAPI
]
# Disabled rules
ignore = []
[tool.sqruff]
output_line_length = 120
max_line_length = 120
# Type checking
[tool.mypy]
mypy_path = "./src"
exclude = ["tests/"]
# Unit tests
[tool.pytest.ini_options]
addopts = ["-rA", "--cov=dative"]
testpaths = ["tests/unit"]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
python_files = ["*.py"]
# Code coverage
[tool.coverage.run]
branch = true
parallel = true
omit = ["**/__init__.py"]
[tool.coverage.report]
# fail_under = 85
show_missing = true
sort = "cover"
exclude_lines = ["no cov", "if __name__ == .__main__.:"]


@ -0,0 +1,8 @@
import duckdb
extensions = [
"excel",
"httpfs"
]
for ext in extensions:
duckdb.install_extension(ext)


@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
import argparse
import sys
from pathlib import Path
from sqlglot import transpile
from sqlglot.dialects.dialect import DIALECT_MODULE_NAMES
from sqlglot.errors import ParseError
def transpile_sql_file(file: Path, dialect: str) -> None:
if dialect not in DIALECT_MODULE_NAMES:
raise ValueError(f"Dialect {dialect} not supported")
if not file.is_file() or not file.suffix == ".sql":
print("Please specify a sql file")
return
sql_txt = file.read_text(encoding="utf-8")
try:
sqls = transpile(sql_txt, read=dialect, write=dialect, pretty=True, indent=4, pad=4)
file.write_text(";\n\n".join(sqls) + "\n", encoding="utf-8")
except ParseError as e:
print(str(e))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Format .sql files in place with sqlglot's transpiler")
parser.add_argument("files_or_dirs", nargs='+', type=str, default=None, help="files or directories")
parser.add_argument('-d', '--dialect', type=str, help="SQL dialect (defaults to the file's secondary suffix, e.g. foo.mysql.sql -> mysql)")
args = parser.parse_args()
if args.files_or_dirs:
for file_or_dir in args.files_or_dirs:
fd = Path(file_or_dir)
if fd.is_file():
dialect = args.dialect or fd.stem.split(".")[-1]
transpile_sql_file(fd, dialect)
elif fd.is_dir():
for sql_file in fd.glob("*.sql"):
dialect = args.dialect or sql_file.stem.split(".")[-1]
transpile_sql_file(sql_file, dialect)
else:
print(f"{file_or_dir} is not a sql file or a directory")
sys.exit(3)
else:
print("Please specify a file or a directory")
sys.exit(3)
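
The script is essentially a thin wrapper around `sqlglot.transpile`; a standalone sketch of the underlying call it performs on each file:

```python
# Standalone illustration of the formatting call used above; the SQL text is an example.
from sqlglot import transpile

sql = "select id, name from users where id in (select user_id from orders)"
formatted = transpile(sql, read="mysql", write="mysql", pretty=True, indent=4, pad=4)
print(";\n\n".join(formatted))
```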


@ -0,0 +1 @@
# -*- coding: utf-8 -*-


@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
from fastapi import APIRouter
from . import data_source
router = APIRouter()
router.include_router(data_source.router, prefix="/data_source", tags=["data_source"])
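
For context, `src/dative/main.py` is not part of this hunk; a minimal sketch of how this router might be mounted on the application (the `/api/v1` prefix is an assumption):

```python
# Hypothetical application wiring; the prefix below is an assumption,
# not taken from this diff.
from fastapi import FastAPI

from dative.api import v1

app = FastAPI(title="dative")
app.include_router(v1.router, prefix="/api/v1")
```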


@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
import os
from typing import Any, Literal, Union
from pydantic import BaseModel, Field, SecretStr
from dative.core.data_source import DatabaseMetadata
class DatabaseConfigBase(BaseModel):
host: str = Field(description="host")
port: int = Field(default=3306, description="port")
username: str = Field(description="username")
password: str = Field(description="password")
class SqliteConfig(BaseModel):
type: Literal["sqlite"] = "sqlite"
db_path: str = Field(description="db_path")
class MysqlConfig(DatabaseConfigBase):
type: Literal["mysql", "maria"] = "mysql"
db_name: str = Field(description="数据库名称")
conn_pool_size: int = Field(default=3, description="数据库连接池大小")
class PostgresConfig(DatabaseConfigBase):
type: Literal["postgres"] = "postgres"
db_name: str = Field(description="数据库名称")
ns_name: str = Field(default="public", alias="schema", description="ns_name")
conn_pool_size: int = Field(default=3, description="数据库连接池大小")
class DuckdbLocalStoreConfig(BaseModel):
type: Literal["local"] = Field(default="local", description="数据库存储方式")
dir_path: str = Field(description="Excel文件目录")
class DuckdbS3StoreConfig(BaseModel):
type: Literal["s3"] = Field(default="s3", description="数据库存储方式")
host: str = Field(description="host")
port: int = Field(default=3306, description="port")
access_key: str = Field(description="access_key")
secret_key: str = Field(description="secret_key")
bucket: str = Field(description="bucket")
region: str = Field(default="", description="region")
use_ssl: bool = Field(default=False, description="use_ssl")
class DuckdbConfig(BaseModel):
type: Literal["duckdb"] = "duckdb"
store: DuckdbLocalStoreConfig | DuckdbS3StoreConfig = Field(description="数据库存储方式", discriminator="type")
DataSourceConfig = Union[SqliteConfig, MysqlConfig, PostgresConfig, DuckdbConfig]
def default_api_key() -> SecretStr:
if os.getenv("AIPROXY_API_TOKEN"):
api_key = SecretStr(str(os.getenv("AIPROXY_API_TOKEN")))
else:
api_key = SecretStr("")
return api_key
def default_base_url() -> str:
if os.getenv("AIPROXY_API_ENDPOINT"):
base_url = str(os.getenv("AIPROXY_API_ENDPOINT")) + "/v1"
else:
base_url = "https://api.openai.com/v1"
return base_url
class LLMInfo(BaseModel):
provider: Literal["openai"] = Field(default="openai", description="LLM提供者")
model: str = Field(description="LLM模型名称")
api_key: SecretStr = Field(default_factory=default_api_key, description="API密钥", examples=["sk-..."])
base_url: str = Field(
default_factory=default_base_url, description="API基础URL", examples=["https://api.openai.com/v1"]
)
temperature: float = Field(default=0.7, description="温度参数")
max_tokens: int | None = Field(default=None, description="最大生成长度")
extra_body: dict[str, Any] | None = Field(default=None)
class SqlQueryRequest(BaseModel):
source_config: DataSourceConfig = Field(description="数据库连接信息", discriminator="type")
sql: str
class SqlQueryResponse(BaseModel):
cols: list[str] = Field(default=list(), description="查询结果列")
data: list[tuple[Any, ...]] = Field(default=list(), description="查询结果数据")
class QueryByNLRequest(BaseModel):
source_config: DataSourceConfig = Field(description="数据库连接信息", discriminator="type")
query: str = Field(description="用户问题")
retrieved_metadata: DatabaseMetadata = Field(description="检索到的元数据")
generate_sql_llm: LLMInfo = Field(description="生成sql的LLM配置信息")
evaluate_sql_llm: LLMInfo | None = Field(default=None, description="评估sql的LLM配置信息")
schema_format: Literal["markdown", "m_schema"] = Field(default="markdown", description="schema转成prompt格式")
evidence: str = Field(default="", description="补充说明信息")
result_num_limit: int = Field(default=100, description="结果数量限制")
class QueryByNLResponse(BaseModel):
answer: str = Field(description="生成的答案")
sql: str = Field(description="生成的SQL语句")
sql_res: SqlQueryResponse = Field(description="SQL查询结果")
input_tokens: int = Field(description="输入token数")
output_tokens: int = Field(description="输出token数")
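
A small sketch of building a `QueryByNLRequest` from these models; the database path, question, and model name are placeholders:

```python
# Illustrative payload construction; all concrete values are placeholders.
from dative.api.v1.data_model import LLMInfo, QueryByNLRequest, SqliteConfig
from dative.core.data_source import DatabaseMetadata

request = QueryByNLRequest(
    source_config=SqliteConfig(db_path="datasets/example.db"),
    query="How many orders were placed last month?",
    retrieved_metadata=DatabaseMetadata(name="example"),
    generate_sql_llm=LLMInfo(model="gpt-4o-mini"),
)
print(request.model_dump_json(exclude_none=True))
```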


@ -0,0 +1,212 @@
# -*- coding: utf-8 -*-
import traceback
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import ORJSONResponse
from langchain_openai import ChatOpenAI
from dative.api.v1.data_model import (
DataSourceConfig,
QueryByNLRequest,
QueryByNLResponse,
SqlQueryRequest,
SqlQueryResponse,
)
from dative.core.agent import Agent
from dative.core.data_source import (
DatabaseMetadata,
DataSourceBase,
DBServerVersion,
DuckdbLocalStore,
DuckdbS3Store,
Mysql,
Postgres,
Sqlite,
)
router = APIRouter()
async def valid_data_source_config(source_config: DataSourceConfig) -> DataSourceBase:
ds: DataSourceBase
if source_config.type == "mysql" or source_config.type == "maria":
ds = Mysql(
host=source_config.host,
port=source_config.port,
username=source_config.username,
password=source_config.password,
db_name=source_config.db_name,
)
elif source_config.type == "postgres":
ds = Postgres(
host=source_config.host,
port=source_config.port,
username=source_config.username,
password=source_config.password,
db_name=source_config.db_name,
schema=source_config.ns_name,
)
elif source_config.type == "duckdb":
if source_config.store.type == "local":
ds = DuckdbLocalStore(dir_path=source_config.store.dir_path)
elif source_config.store.type == "s3":
ds = DuckdbS3Store(
host=source_config.store.host,
port=source_config.store.port,
access_key=source_config.store.access_key,
secret_key=source_config.store.secret_key,
bucket=source_config.store.bucket,
region=source_config.store.region,
use_ssl=source_config.store.use_ssl,
)
else:
raise HTTPException(
status_code=400,
detail={
"msg": f"Unsupported duckdb storage type: {source_config.store}",
"error": f"Unsupported duckdb storage type: {source_config.store}",
},
)
elif source_config.type == "sqlite":
ds = Sqlite(source_config.db_path)
else:
raise HTTPException(
status_code=400,
detail={
"msg": f"Unsupported data source types: {source_config.type}",
"error": f"Unsupported data source types: {source_config.type}",
},
)
try:
await ds.conn_test()
return ds
except Exception as e:
raise HTTPException(
status_code=400,
detail={
"msg": "Connection failed. Please check the connection information.",
"error": str(e),
},
)
@router.post("/conn_test", response_class=ORJSONResponse, dependencies=[Depends(valid_data_source_config)])
async def conn_test() -> str:
return "ok"
@router.post("/get_metadata", response_class=ORJSONResponse)
async def get_metadata(ds: Annotated[DataSourceBase, Depends(valid_data_source_config)]) -> DatabaseMetadata:
try:
return await ds.aget_metadata()
except Exception as e:
raise HTTPException(
status_code=400,
detail={
"msg": "Failed to obtain database metadata",
"error": str(e),
},
)
@router.post("/get_metadata_with_value_examples", response_class=ORJSONResponse)
async def get_metadata_with_value_example(ds: Annotated[DataSourceBase, Depends(valid_data_source_config)]) -> DatabaseMetadata:
try:
return await ds.aget_metadata_with_value_examples()
except Exception as e:
raise HTTPException(
status_code=400,
detail={
"msg": "Failed to obtain database metadata",
"error": str(e),
},
)
@router.post("/get_server_version", response_class=ORJSONResponse)
async def get_server_version(ds: Annotated[DataSourceBase, Depends(valid_data_source_config)]) -> DBServerVersion:
try:
return await ds.aget_server_version()
except Exception as e:
raise HTTPException(
status_code=400,
detail={
"msg": "Failed to obtain database version information",
"error": str(e),
},
)
@router.post("/sql_query", response_class=ORJSONResponse)
async def sql_query(request: SqlQueryRequest) -> SqlQueryResponse:
ds = await valid_data_source_config(source_config=request.source_config)
cols, data, err = await ds.aexecute_raw_sql(sql=request.sql)
if err:
raise HTTPException(
status_code=400,
detail={
"msg": "Database query failed",
"error": err.msg,
},
)
return SqlQueryResponse(cols=cols, data=data)
@router.post("/query_by_nl", response_class=ORJSONResponse)
async def query_by_nl(request: QueryByNLRequest) -> QueryByNLResponse:
ds = await valid_data_source_config(source_config=request.source_config)
try:
server_version = await ds.aget_server_version()
except Exception as e:
print(traceback.format_exc())
raise HTTPException(
status_code=400,
detail={
"msg": "Connection failed. Please check the connection information.",
"error": str(e),
},
)
generate_sql_llm = ChatOpenAI(
model=request.generate_sql_llm.model,
api_key=request.generate_sql_llm.api_key,
base_url=request.generate_sql_llm.base_url,
stream_usage=True,
extra_body=request.generate_sql_llm.extra_body
)
if request.evaluate_sql_llm:
evaluate_sql_llm = ChatOpenAI(
model=request.evaluate_sql_llm.model,
api_key=request.evaluate_sql_llm.api_key,
base_url=request.evaluate_sql_llm.base_url,
stream_usage=True,
extra_body=request.evaluate_sql_llm.extra_body
)
else:
evaluate_sql_llm = None
agent = Agent(
ds=ds,
db_server_version=server_version,
generate_sql_llm=generate_sql_llm,
schema_format=request.schema_format,
result_num_limit=request.result_num_limit,
evaluate_sql_llm=evaluate_sql_llm,
)
try:
answer, sql, cols, data, total_input_tokens, total_output_tokens = await agent.arun(
query=request.query, metadata=request.retrieved_metadata, evidence=request.evidence
)
except Exception as e:
print(traceback.format_exc())
raise HTTPException(
status_code=400,
detail={
"msg": "Database search failed",
"error": str(e),
},
)
return QueryByNLResponse(
answer=answer,
sql=sql,
sql_res=SqlQueryResponse(cols=cols, data=data),
input_tokens=total_input_tokens,
output_tokens=total_output_tokens,
)


@ -0,0 +1 @@
# -*- coding: utf-8 -*-


@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
from typing import ClassVar
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
env_file=".env",
env_prefix="DATIVE_",
env_file_encoding="utf-8",
extra="ignore",
env_nested_delimiter="__",
nested_model_default_partial_update=True,
)
DATABASE_URL: str = ""
log_level: str = "INFO"
log_format: str = "plain"
redis_url: str = ""
settings = Settings()
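
Configuration is resolved from the environment (or a `.env` file) using the `DATIVE_` prefix. A small illustrative sketch; the module path `dative.config` is an assumption, since the file path is not shown in this hunk:

```python
# Illustrative only; the import path is an assumption and the values are placeholders.
import os

os.environ["DATIVE_LOG_LEVEL"] = "DEBUG"
os.environ["DATIVE_REDIS_URL"] = "redis://localhost:6379/0"

from dative.config import Settings  # path assumed

print(Settings().log_level)  # -> "DEBUG"
print(Settings().redis_url)  # -> "redis://localhost:6379/0"
```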


@ -0,0 +1 @@
# -*- coding: utf-8 -*-


@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
from typing import Any, Literal, Optional
from langchain_core.language_models import BaseChatModel
from sqlglot.errors import SqlglotError
from . import sql_generation as sql_gen
from . import sql_res_evaluation as sql_eval
from .data_source import DatabaseMetadata, DataSourceBase, DBServerVersion, SQLError
from .sql_inspection import SQLCheck, SQLOptimization
class Agent:
def __init__(
self,
ds: DataSourceBase,
db_server_version: DBServerVersion,
generate_sql_llm: BaseChatModel,
result_num_limit: int = 100,
schema_format: Literal["markdown", "m_schema"] = "m_schema",
max_attempts: int = 2,
evaluate_sql_llm: Optional[BaseChatModel] = None,
):
self.ds = ds
self.generate_sql_llm = generate_sql_llm
self.schema_format = schema_format
self.db_server_version = db_server_version
self.sql_check = SQLCheck(ds.dialect)
self.sql_opt = SQLOptimization(dialect=ds.dialect, db_major_version=db_server_version.major)
self.result_num_limit = result_num_limit
self.max_attempts = max_attempts
self.evaluate_sql_llm = evaluate_sql_llm
async def arun(
self, query: str, metadata: DatabaseMetadata, evidence: str = ""
) -> tuple[str, str, list[str], list[tuple[Any, ...]], int, int]:
answer = None
sql = ""
cols: list[str] = []
data: list[tuple[Any, ...]] = []
total_input_tokens = 0
total_output_tokens = 0
attempts = 0
schema_str = f"# Database server info: {self.ds.dialect} {self.db_server_version}\n"
if self.schema_format == "markdown":
schema_str += metadata.to_markdown()
else:
schema_str += metadata.to_m_schema()
schema_type = {table.name: table.column_type() for table in metadata.tables if table.enabled}
current_sql, error_msg = None, None
while attempts < self.max_attempts:
answer, sql, input_tokens, output_tokens, error_msg = await sql_gen.arun(
query=query,
llm=self.generate_sql_llm,
db_info=schema_str,
evidence=evidence,
error_sql=current_sql,
error_msg=error_msg,
)
total_input_tokens += input_tokens
total_output_tokens += output_tokens
if answer is not None:
return answer, sql, cols, data, total_input_tokens, total_output_tokens
elif error_msg is None and sql:
try:
sql_exp = self.sql_check.syntax_valid(sql)
if not self.sql_check.is_query(sql_exp):
return answer or "", sql, cols, data, total_input_tokens, total_output_tokens
sql = self.sql_opt.arun(sql_exp, schema_type=schema_type, result_num_limit=self.result_num_limit)
except SqlglotError as e:
error_msg = f"SQL语法错误{e}"
if not error_msg:
cols, data, err = await self.ds.aexecute_raw_sql(sql)
if err:
if err.error_type != SQLError.SyntaxError:
return answer or "", sql, cols, data, total_input_tokens, total_output_tokens
else:
error_msg = err.msg
if not error_msg and self.evaluate_sql_llm:
answer, input_tokens, output_tokens, error_msg = await sql_eval.arun(
query=query,
llm=self.evaluate_sql_llm,
db_info=schema_str,
sql=sql,
res_rows=data,
evidence=evidence,
)
total_input_tokens += input_tokens
total_output_tokens += output_tokens
if not error_msg:
break
attempts += 1
current_sql = sql
return answer or "", sql, cols, data, total_input_tokens, total_output_tokens
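
A minimal sketch of driving `Agent` directly against a SQLite file; the database path, question, and model name are placeholders, and an OpenAI-compatible endpoint is assumed to be configured via the usual environment variables:

```python
# Illustrative wiring only; db path, question, and model name are placeholders.
import asyncio

from langchain_openai import ChatOpenAI

from dative.core.agent import Agent
from dative.core.data_source import Sqlite


async def main() -> None:
    ds = Sqlite("datasets/example.db")
    agent = Agent(
        ds=ds,
        db_server_version=await ds.aget_server_version(),
        generate_sql_llm=ChatOpenAI(model="gpt-4o-mini"),
    )
    metadata = await ds.aget_metadata_with_value_examples()
    answer, sql, cols, data, in_tokens, out_tokens = await agent.arun(
        query="How many rows are in each table?", metadata=metadata
    )
    print(answer or sql, cols, data, in_tokens, out_tokens)


asyncio.run(main())
```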


@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
from .base import (
DatabaseMetadata,
DataSourceBase,
DBException,
DBServerVersion,
DBTable,
SQLError,
SQLException,
TableColumn,
TableForeignKey,
)
from .duckdb_ds import DuckdbLocalStore, DuckdbS3Store
from .mysql_ds import Mysql
from .postgres_ds import Postgres
from .sqlite_ds import Sqlite
__all__ = [
"DataSourceBase",
"DatabaseMetadata",
"DBServerVersion",
"TableColumn",
"TableForeignKey",
"DBTable",
"SQLError",
"SQLException",
"Mysql",
"Postgres",
"DBException",
"DuckdbLocalStore",
"DuckdbS3Store",
"Sqlite",
]


@ -0,0 +1,288 @@
# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from enum import StrEnum, auto
from typing import Any, cast
import orjson
from pydantic import BaseModel, Field
from sqlglot import exp
from dative.core.utils import convert_value2str, is_date, is_email, is_number, is_valid_uuid, truncate_text
class TableColumn(BaseModel):
name: str = Field(description="列名")
type: str = Field(description="数据类型")
comment: str = Field(default="", description="描述信息")
auto_increment: bool = Field(default=False, description="是否自增")
nullable: bool = Field(default=True, description="是否允许为空")
default: Any = Field(default=None, description="默认值")
examples: list[Any] = Field(default_factory=list, description="值样例")
enabled: bool = Field(default=True, description="是否启用")
value_index: bool = Field(default=False, description="是否启用值索引")
class ConstraintKey(BaseModel):
name: str = Field(description="约束名")
column: str = Field(description="约束字段名")
class TableForeignKey(ConstraintKey):
referenced_schema: str = Field(description="引用schema")
referenced_table: str = Field(description="引用表名")
referenced_column: str = Field(description="引用字段名")
class DBTable(BaseModel):
name: str = Field(description="表名")
ns_name: str | None = Field(default=None, alias="schema", description="ns_name")
comment: str = Field(default="", description="描述信息")
columns: dict[str, TableColumn] = Field(default_factory=dict, description="")
primary_keys: list[str] = Field(default_factory=list, description="主键")
foreign_keys: list[TableForeignKey] = Field(default_factory=list, description="外键")
enabled: bool = Field(default=True, description="是否启用")
def column_type(self, case_insensitive: bool = True) -> dict[str, str]:
schema_type: dict[str, str] = dict()
for col in self.columns.values():
schema_type[col.name] = col.type
if case_insensitive:
schema_type[col.name.title()] = col.type
schema_type[col.name.lower()] = col.type
schema_type[col.name.upper()] = col.type
return schema_type
def to_markdown(self) -> str:
s = f"""
### Table name: {self.name}
{self.comment}
#### Columns:
|Name|Description|Type|Examples|
|---|---|---|---|
"""
for col in self.columns.values():
if not col.enabled:
continue
example_values = "<br>".join([convert_value2str(v) for v in col.examples])
s += f"|{col.name}|{col.comment}|{col.type}|{example_values}|\n"
if self.primary_keys:
s += f"#### Primary Keys: {tuple(self.primary_keys)}\n"
if self.foreign_keys:
fk_str = "#### Foreign Keys: \n"
for fk in self.foreign_keys:
# Only show foreign keys that reference tables in the same schema
if self.ns_name is None or self.ns_name == fk.referenced_schema:
fk_str += f" - {fk.column} -> {fk.referenced_table}.{fk.referenced_column}\n"
s += fk_str
return s
def to_m_schema(self, db_name: str | None = None) -> str:
# XiYanSQL: https://github.com/XGenerationLab/M-Schema
output = []
if db_name:
table_comment = f"# Table: {db_name}.{self.name}"
else:
table_comment = f"# Table: {self.name}"
if self.comment:
table_comment += f", {self.comment}"
output.append(table_comment)
field_lines = []
for col in self.columns.values():
if not col.enabled:
continue
field_line = f"({col.name}: {col.type.upper()}"
if col.name in self.primary_keys:
field_line += ", Primary Key"
if col.comment:
field_line += f", {col.comment.strip()}"
if col.examples:
example_values = ", ".join([convert_value2str(v) for v in col.examples])
field_line += f", Examples: [{example_values}]"
field_line += ")"
field_lines.append(field_line)
output.append("[")
output.append(",\n".join(field_lines))
output.append("]")
return "\n".join(output)
class SQLError(StrEnum):
EmptySQL = auto()
SyntaxError = auto()
NotAllowedOperation = auto()
DBError = auto()
UnknownError = auto()
class SQLException(BaseModel):
error_type: SQLError
msg: str
code: int = Field(default=0, description="错误码")
class DatabaseMetadata(BaseModel):
name: str = Field(description="数据库名")
comment: str = Field(default="", description="描述信息")
tables: list[DBTable] = Field(default_factory=list, description="")
def to_markdown(self) -> str:
s = f"""
# Database name: {self.name}
{self.comment}
## Tables:
"""
for table in self.tables:
if table.enabled:
s += f"{table.to_markdown()}\n"
return s
def to_m_schema(self) -> str:
output = [f"【DB_ID】 {self.name}"]
if self.comment:
output.append(self.comment)
output.append("【Schema】")
output.append("schema format: (column name: data type, is primary key, comment, examples)")
fks = []
for t in self.tables:
if not t.enabled:
continue
output.append(t.to_m_schema(self.name))
for fk in t.foreign_keys:
# Only show foreign keys that reference tables in the same schema
if t.ns_name is None or t.ns_name == fk.referenced_schema:
fks.append(f"{t.name}.{fk.column}={fk.referenced_table}.{fk.referenced_column}")
if fks:
output.append("【Foreign Keys】")
output.extend(fks)
output.append("\n")
return "\n".join(output)
class DBServerVersion(BaseModel):
major: int = Field(description="主版本号")
minor: int = Field(description="次版本号")
patch: int | None = Field(default=None, description="修订号/补丁号")
def __str__(self) -> str:
s = f"{self.major}.{self.minor}"
if self.patch is None:
return s
return s + f".{self.patch}"
def __repr__(self) -> str:
return str(self)
class DBException(BaseModel):
code: int = Field(description="异常码")
msg: str = Field(description="异常信息")
class DataSourceBase(ABC):
@property
@abstractmethod
def dialect(self) -> str:
"""return the type of data source"""
raise NotImplementedError
@property
@abstractmethod
def string_types(self) -> set[str]:
"""return the string types of data source"""
raise NotImplementedError
@property
@abstractmethod
def json_array_agg_func(self) -> str:
"""
Get the SQL expression for the JSON array aggregation function.
Returns:
str: the name of the JSON array aggregation function for the current database dialect.
"""
raise NotImplementedError
async def conn_test(self) -> bool:
await self.aexecute_raw_sql("SELECT 1")
return True
@abstractmethod
async def aget_server_version(self) -> DBServerVersion:
"""get server version"""
@abstractmethod
async def aget_metadata(self) -> DatabaseMetadata:
"""get database metadata"""
@abstractmethod
async def aexecute_raw_sql(self, sql: str) -> tuple[list[str], list[tuple[Any, ...]], SQLException | None]:
""""""
@staticmethod
async def distinct_values_exp(table_name: str, col_name: str, limit: int = 3) -> exp.Expression:
sql_exp = (
exp.select(exp.column(col_name))
.distinct()
.from_(exp.to_identifier(table_name))
.where(exp.not_(exp.column(col_name).is_(exp.null())))
.limit(limit)
)
return sql_exp
async def aget_metadata_with_value_examples(
self, value_num: int = 3, value_str_max_length: int = 40
) -> DatabaseMetadata:
"""
Get database metadata enriched with example values for each column.
Args:
value_num (int): number of example values to collect per column (default: 3).
value_str_max_length (int): maximum length of string example values (default: 40).
Returns:
DatabaseMetadata: metadata object including example value information.
"""
db_metadata = await self.aget_metadata()
for t in db_metadata.tables:
col_exp = []
for col in t.columns.values():
col_exp.append(await self.agg_distinct_values(t.name, col.name, value_num))
sql = exp.union(*col_exp, distinct=False).sql(dialect=self.dialect)
_, rows, _ = await self.aexecute_raw_sql(sql)
rows = cast(list[tuple[str, str]], rows)
for i in range(len(rows)):
col_name = rows[i][0]
if rows[i][1]:
examples = orjson.loads(cast(str, rows[i][1]))
str_examples = [
truncate_text(convert_value2str(v), max_length=value_str_max_length) for v in examples
]
t.columns[col_name].examples = str_examples
if (
t.columns[col_name].type in self.string_types
and not is_valid_uuid(str_examples[0])
and not is_number(str_examples[0])
and not is_date(str_examples[0])
and not is_email(str_examples[0])
):
t.columns[col_name].value_index = True
return db_metadata
async def agg_distinct_values(self, table_name: str, col_name: str, limit: int = 3) -> exp.Expression:
dis_exp = await self.distinct_values_exp(table_name, col_name, limit)
sql_exp = exp.select(
exp.Alias(this=exp.Literal.string(col_name), alias="name"),
exp.Alias(this=exp.func(self.json_array_agg_func, exp.column(col_name)), alias="examples"),
).from_(exp.Subquery(this=dis_exp, alias="t"))
return sql_exp
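
To make the two schema-to-prompt formats concrete, a small sketch that builds metadata by hand and renders it (the table and columns are invented for illustration):

```python
# Illustrative only: hand-built metadata rendered as Markdown and M-Schema.
from dative.core.data_source import DatabaseMetadata, DBTable, TableColumn

metadata = DatabaseMetadata(
    name="shop",
    tables=[
        DBTable(
            name="orders",
            comment="customer orders",
            columns={
                "id": TableColumn(name="id", type="INTEGER"),
                "amount": TableColumn(name="amount", type="DECIMAL", examples=[12.5, 99]),
            },
            primary_keys=["id"],
        )
    ],
)

print(metadata.to_markdown())
print(metadata.to_m_schema())
```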


@ -0,0 +1,212 @@
# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, cast
import duckdb
import orjson
import pyarrow as pa
from aiobotocore.session import get_session
from botocore.exceptions import ClientError
from fastexcel import read_excel
from .base import DatabaseMetadata, DataSourceBase, DBServerVersion, SQLError, SQLException
with open(Path(__file__).parent / "metadata_sql" / "duckdb.sql", encoding="utf-8") as f:
METADATA_SQL = f.read()
supported_file_extensions = {"xlsx", "xls", "xlsm", "xlsb", "csv", "json", "parquet"}
class DuckdbBase(DataSourceBase, ABC):
"""
DuckDB vs. polars: polars currently does not support SQL query optimization over Excel files (its lazy-loading API is pl.scan_*).
For large files with many rows, SQL query optimization avoids loading the full dataset into memory.
"""
def __init__(self, db_name: str):
self.db_name = db_name
self.conn = self.get_conn()
self.loaded = False
@abstractmethod
def get_conn(self) -> duckdb.DuckDBPyConnection:
""""""
@property
def dialect(self) -> str:
return "duckdb"
@property
def string_types(self) -> set[str]:
return {"VARCHAR", "CHAR", "BPCHAR", "TEXT", "STRING"}
@property
def json_array_agg_func(self) -> str:
return "JSON_GROUP_ARRAY"
async def aget_server_version(self) -> DBServerVersion:
version = [int(i) for i in duckdb.__version__.split(".") if i.isdigit()]
server_version = DBServerVersion(major=version[0], minor=version[1])
if len(version) > 2:
server_version.patch = version[2]
return server_version
@abstractmethod
async def read_files(self) -> list[tuple[str, str]]:
""""""
async def load_data(self) -> None:
if self.loaded:
return
# Reset the connection before loading new data
self.conn = self.get_conn()
files = await self.read_files()
for name, file_path in files:
extension = Path(file_path).suffix.split(".")[-1].lower()
if extension == "xlsx":
sql = f"select * from read_xlsx('{file_path}')"
self.conn.register(f"{name}", self.conn.sql(sql))
elif extension == "csv":
sql = f"select * from read_csv('{file_path}')"
self.conn.register(f"{name}", self.conn.sql(sql))
elif extension == "json":
sql = f"select * from read_json('{file_path}')"
self.conn.register(f"{name}", self.conn.sql(sql))
elif extension == "parquet":
sql = f"select * from read_parquet('{file_path}')"
self.conn.register(f"{name}", self.conn.sql(sql))
else:
wb = read_excel(file_path)
record_batch = wb.load_sheet(wb.sheet_names[0]).to_arrow()
self.conn.register(f"{name}", pa.Table.from_batches([record_batch]))
self.loaded = True
async def aexecute_raw_sql(self, sql: str) -> tuple[list[str], list[tuple[Any, ...]], SQLException | None]:
try:
df: duckdb.DuckDBPyRelation = self.conn.sql(sql)
return df.columns, df.fetchall(), None
except duckdb.Error as e:
return [], [], SQLException(error_type=SQLError.SyntaxError, msg=str(e))
async def aget_metadata(self) -> DatabaseMetadata:
await self.load_data()
_, rows, err = await self.aexecute_raw_sql(METADATA_SQL.format(db_name=self.db_name))
if err:
raise ConnectionError(err.msg)
if not rows or not rows[0][1]:
return DatabaseMetadata(name=self.db_name)
metadata = DatabaseMetadata.model_validate({
"name": self.db_name,
"tables": orjson.loads(cast(str, rows[0][1])),
})
return metadata
class DuckdbLocalStore(DuckdbBase):
def __init__(self, dir_path: str | Path):
if isinstance(dir_path, str):
dir_path = Path(dir_path)
super().__init__(dir_path.stem)
self.dir_path = dir_path
self.conn = self.get_conn()
self.loaded = False
def get_conn(self) -> duckdb.DuckDBPyConnection:
conn = duckdb.connect(config={"allow_unsigned_extensions": True})
conn.load_extension("excel")
return conn
async def conn_test(self) -> bool:
if self.dir_path.is_dir():
await self.load_data()
return True
return False
async def read_files(self) -> list[tuple[str, str]]:
excel_files: list[tuple[str, str]] = []
for file_path in self.dir_path.glob("*"):
if file_path.is_file() and file_path.suffix.split(".")[-1].lower() in supported_file_extensions:
excel_files.append((file_path.stem, str(file_path.absolute())))
return excel_files
class DuckdbS3Store(DuckdbBase):
def __init__(
self,
host: str,
port: int,
access_key: str,
secret_key: str,
bucket: str,
region: str = "",
use_ssl: bool = False,
):
self.host = host
self.port = port
self.endpoint = f"{self.host}:{self.port}"
self.access_key = access_key
self.secret_key = secret_key
self.bucket = bucket
self.db_name = bucket
self.region = region
self.use_ssl = use_ssl
self.conn = self.get_conn()
super().__init__(self.db_name)
if self.use_ssl:
self.endpoint_url = f"https://{self.endpoint}"
else:
self.endpoint_url = f"http://{self.endpoint}"
self.session = get_session()
def get_conn(self) -> duckdb.DuckDBPyConnection:
conn = duckdb.connect()
conn.load_extension("httpfs")
sql = f"""
CREATE OR REPLACE SECRET secret (
TYPE s3,
ENDPOINT '{self.endpoint}',
KEY_ID '{self.access_key}',
SECRET '{self.secret_key}',
USE_SSL {orjson.dumps(self.use_ssl).decode()},
URL_STYLE 'path',
REGION '{self.region}'
);
"""
conn.execute(sql)
return conn
async def conn_test(self) -> bool:
try:
async with self.session.create_client(
service_name="s3",
endpoint_url="http://localhost:9000",
aws_secret_access_key=self.secret_key,
aws_access_key_id=self.access_key,
region_name=self.region,
) as client:
await client.head_bucket(Bucket=self.bucket)
return True
except ClientError:
return False
async def read_files(self) -> list[tuple[str, str]]:
excel_files: list[tuple[str, str]] = []
async with self.session.create_client(
service_name="s3",
endpoint_url="http://localhost:9000",
aws_secret_access_key=self.secret_key,
aws_access_key_id=self.access_key,
region_name=self.region,
) as client:
paginator = client.get_paginator("list_objects_v2")
async for result in paginator.paginate(Bucket=self.bucket):
for obj in result.get("Contents", []):
if obj["Key"].split(".")[-1].lower() in supported_file_extensions:
excel_files.append((obj["Key"].split(".")[0], f"s3://{self.bucket}/{obj['Key']}"))
return excel_files
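
A minimal sketch of querying a local directory of data files with `DuckdbLocalStore`; the directory and the `sales` table name are placeholders (each file is registered as a view named after its stem):

```python
# Illustrative usage; the directory path and table name are placeholders.
import asyncio

from dative.core.data_source import DuckdbLocalStore


async def main() -> None:
    ds = DuckdbLocalStore("datasets/reports")
    if not await ds.conn_test():
        raise SystemExit("directory not found")
    metadata = await ds.aget_metadata()
    print([t.name for t in metadata.tables])
    cols, rows, err = await ds.aexecute_raw_sql("SELECT COUNT(*) FROM sales")
    print(cols, rows, err)


asyncio.run(main())
```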


@ -0,0 +1,19 @@
SELECT
'{db_name}' AS db_name,
JSON_GROUP_ARRAY(JSON_OBJECT('name', t.table_name, 'columns', t.columns)) AS tables
FROM (
SELECT
t.table_name,
JSON_GROUP_OBJECT(
t.column_name,
JSON_OBJECT(
'name', t.column_name,
'default', t.column_default,
'nullable', CASE WHEN t.is_nullable = 'YES' THEN TRUE ELSE FALSE END,
'type', UPPER(t.data_type)
)
) AS columns
FROM information_schema.columns AS t
GROUP BY
t.table_name
) AS t


@ -0,0 +1,91 @@
SELECT
t.table_schema,
JSON_ARRAYAGG(
JSON_OBJECT(
'name', t.table_name,
'columns', t.columns,
'primary_keys', COALESCE(t.primary_keys, JSON_ARRAY()),
'foreign_keys', COALESCE(t.foreign_keys, JSON_ARRAY()),
'comment', t.table_comment
)
) AS tables
FROM (
SELECT
t1.table_schema,
t1.table_name,
t1.table_comment,
t2.columns,
t3.primary_keys,
t4.foreign_keys
FROM information_schema.tables AS t1
JOIN (
SELECT
t.table_schema,
t.table_name,
JSON_OBJECTAGG(
t.column_name, JSON_OBJECT(
'name', t.column_name,
'default', t.column_default,
'nullable', CASE WHEN t.is_nullable = 'YES' THEN CAST(TRUE AS JSON) ELSE CAST(FALSE AS JSON) END,
'type', UPPER(t.data_type),
'auto_increment', CASE
WHEN t.extra = 'auto_increment'
THEN CAST(TRUE AS JSON)
ELSE CAST(FALSE AS JSON)
END,
'comment', t.column_comment
)
) AS columns
FROM information_schema.columns AS t
WHERE
table_schema = '{db_name}'
GROUP BY
t.table_schema,
t.table_name
) AS t2
ON t1.table_schema = t2.table_schema AND t1.table_name = t2.table_name
LEFT JOIN (
SELECT
kcu.table_schema,
kcu.table_name,
JSON_ARRAYAGG(kcu.column_name) AS primary_keys
FROM information_schema.key_column_usage AS kcu
JOIN information_schema.table_constraints AS tc
ON kcu.table_schema = '{db_name}'
AND kcu.constraint_name = tc.constraint_name
AND kcu.constraint_schema = tc.constraint_schema
AND kcu.table_name = tc.table_name
AND tc.constraint_type = 'PRIMARY KEY'
GROUP BY
kcu.table_schema,
kcu.table_name
) AS t3
ON t1.table_schema = t3.table_schema AND t1.table_name = t3.table_name
LEFT JOIN (
SELECT
kcu.table_schema,
kcu.table_name,
JSON_ARRAYAGG(
JSON_OBJECT(
'name', kcu.constraint_name,
'column', kcu.column_name,
'referenced_schema', kcu.referenced_table_schema,
'referenced_table', kcu.referenced_table_name,
'referenced_column', kcu.referenced_column_name
)
) AS foreign_keys
FROM information_schema.key_column_usage AS kcu
JOIN information_schema.table_constraints AS tc
ON kcu.table_schema = '{db_name}'
AND kcu.constraint_name = tc.constraint_name
AND kcu.constraint_schema = tc.constraint_schema
AND kcu.table_name = tc.table_name
AND tc.constraint_type = 'FOREIGN KEY'
GROUP BY
kcu.table_schema,
kcu.table_name
) AS t4
ON t1.table_schema = t4.table_schema AND t1.table_name = t4.table_name
) AS t
GROUP BY
t.table_schema


@ -0,0 +1,174 @@
SELECT
t.table_schema,
JSON_AGG(
JSON_BUILD_OBJECT(
'name',
t.table_name,
'schema',
t.table_schema,
'columns',
t.columns,
'primary_keys',
COALESCE(t.primary_keys, CAST('[]' AS JSON)),
'foreign_keys',
COALESCE(t.foreign_keys, CAST('[]' AS JSON)),
'comment',
t.comment
)
) AS tables
FROM (
SELECT
t1.table_schema,
t1.table_name,
t1.comment,
t2.columns,
t3.primary_keys,
t4.foreign_keys
FROM (
SELECT
n.nspname AS table_schema,
c.relname AS table_name,
COALESCE(OBJ_DESCRIPTION(c.oid), '') AS comment
FROM pg_class AS c
JOIN pg_namespace AS n
ON n.oid = c.relnamespace AND c.relkind = 'r' AND n.nspname = '{schema}'
) AS t1
JOIN (
SELECT
t.table_schema,
t.table_name,
JSON_OBJECT_AGG(
t.column_name,
JSON_BUILD_OBJECT(
'name',
t.column_name,
'default',
t.column_default,
'nullable',
CASE WHEN t.is_nullable = 'YES' THEN TRUE ELSE FALSE END,
'type',
UPPER(t.data_type),
'auto_increment',
t.auto_increment,
'comment',
COALESCE(t.comment, '')
)
) AS columns
FROM (
SELECT
c.table_schema,
c.table_name,
c.column_name,
c.data_type,
c.column_default,
c.is_nullable,
COALESCE(coldesc.comment, '') AS comment,
CASE
WHEN c.is_identity = 'YES'
OR c.column_default LIKE 'nextval%'
OR c.column_default LIKE 'nextval(%'
THEN TRUE
ELSE FALSE
END AS auto_increment
FROM information_schema.columns AS c
JOIN (
SELECT
n.nspname AS table_schema,
c.relname AS table_name,
a.attname AS column_name,
COL_DESCRIPTION(c.oid, a.attnum) AS comment
FROM pg_class AS c
JOIN pg_namespace AS n
ON n.oid = c.relnamespace
JOIN pg_attribute AS a
ON a.attrelid = c.oid
WHERE
a.attnum > 0 AND NOT a.attisdropped AND n.nspname = '{schema}'
) AS coldesc
ON c.table_schema = '{schema}'
AND c.table_schema = coldesc.table_schema
AND c.table_name = coldesc.table_name
AND c.column_name = coldesc.column_name
) AS t
GROUP BY
t.table_schema,
t.table_name
) AS t2
ON t1.table_schema = t2.table_schema AND t1.table_name = t2.table_name
LEFT JOIN (
SELECT
CAST(CAST(connamespace AS REGNAMESPACE) AS TEXT) AS table_schema,
CASE
WHEN POSITION('.' IN CAST(CAST(conrelid AS REGCLASS) AS TEXT)) > 0
THEN SPLIT_PART(CAST(CAST(conrelid AS REGCLASS) AS TEXT), '.', 2)
ELSE CAST(CAST(conrelid AS REGCLASS) AS TEXT)
END AS table_name,
TO_JSON(STRING_TO_ARRAY(SUBSTRING(PG_GET_CONSTRAINTDEF(oid) FROM '\((.*?)\)'), ',')) AS primary_keys
FROM pg_constraint
WHERE
contype = 'p' AND CAST(CAST(connamespace AS REGNAMESPACE) AS TEXT) = '{schema}'
) AS t3
ON t1.table_schema = t3.table_schema AND t1.table_name = t3.table_name
LEFT JOIN (
SELECT
t.table_schema,
t.table_name,
JSON_AGG(
JSON_BUILD_OBJECT(
'name',
t.name,
'column',
t.column_name,
'referenced_schema',
t.referenced_table_schema,
'referenced_table',
t.referenced_table_name,
'referenced_column',
t.referenced_column_name
)
) AS foreign_keys
FROM (
SELECT
c.conname AS name,
n.nspname AS table_schema,
CASE
WHEN POSITION('.' IN CAST(CAST(conrelid AS REGCLASS) AS TEXT)) > 0
THEN SPLIT_PART(CAST(CAST(conrelid AS REGCLASS) AS TEXT), '.', 2)
ELSE CAST(CAST(conrelid AS REGCLASS) AS TEXT)
END AS TABLE_NAME,
A.attname AS COLUMN_NAME,
nr.nspname AS referenced_table_schema,
CASE
WHEN POSITION('.' IN CAST(CAST(confrelid AS REGCLASS) AS TEXT)) > 0
THEN SPLIT_PART(CAST(CAST(confrelid AS REGCLASS) AS TEXT), '.', 2)
ELSE CAST(CAST(confrelid AS REGCLASS) AS TEXT)
END AS referenced_table_name,
af.attname AS referenced_column_name
FROM pg_constraint AS c
JOIN pg_attribute AS a
ON a.attnum = ANY(
c.conkey
) AND a.attrelid = c.conrelid
JOIN pg_class AS cl
ON cl.oid = c.conrelid
JOIN pg_namespace AS n
ON n.oid = cl.relnamespace
JOIN pg_attribute AS af
ON af.attnum = ANY(
c.confkey
) AND af.attrelid = c.confrelid
JOIN pg_class AS clf
ON clf.oid = c.confrelid
JOIN pg_namespace AS nr
ON nr.oid = clf.relnamespace
WHERE
c.contype = 'f' AND CAST(CAST(connamespace AS REGNAMESPACE) AS TEXT) = '{schema}'
) AS t
GROUP BY
t.table_schema,
t.table_name
) AS t4
ON t1.table_schema = t4.table_schema AND t1.table_name = t4.table_name
) AS t
GROUP BY
t.table_schema


@ -0,0 +1,66 @@
SELECT
JSON_GROUP_ARRAY(
JSON_OBJECT(
'name', m.name,
'schema', 'main',
'columns', JSON(t1.columns),
'primary_keys', JSON(t2.primary_keys),
'foreign_keys', COALESCE(JSON(t3.foreign_keys), JSON_ARRAY())
)
) AS tables
FROM sqlite_master AS m
JOIN (
SELECT
m.name,
JSON_GROUP_OBJECT(
p.name,
JSON_OBJECT(
'name', p.name,
'type', CASE
WHEN INSTR(UPPER(p.type), '(') > 0
THEN SUBSTRING(UPPER(p.type), 1, INSTR(UPPER(p.type), '(') - 1)
ELSE UPPER(p.type)
END,
'nullable', (
CASE WHEN p."notnull" = 0 THEN TRUE ELSE FALSE END
),
'default', p.dflt_value
)
) AS columns
FROM sqlite_master AS m
JOIN PRAGMA_TABLE_INFO(m.name) AS p
ON m.type = 'table'
GROUP BY
m.name
) AS t1
ON m.name = t1.name
LEFT JOIN (
SELECT
m.name,
JSON_GROUP_ARRAY(p.name) AS primary_keys
FROM sqlite_master AS m
JOIN PRAGMA_TABLE_INFO(m.name) AS p
ON m.type = 'table' AND p.pk > 0
GROUP BY
m.name
) AS t2
ON m.name = t2.name
LEFT JOIN (
SELECT
m.name,
JSON_GROUP_ARRAY(
JSON_OBJECT(
'name', 'fk_' || m.tbl_name || '_' || fk."from" || '_' || fk."table" || '_' || fk."to",
'column', fk."from",
'referenced_schema', '',
'referenced_table', fk."table",
'referenced_column', fk."to"
)
) AS foreign_keys
FROM sqlite_master AS m
JOIN PRAGMA_FOREIGN_KEY_LIST(m.name) AS fk
ON m.type = 'table'
GROUP BY
m.tbl_name
) AS t3
ON m.name = t3.name


@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
import re
from pathlib import Path
from typing import Any, cast
import orjson
from mysql.connector.aio import MySQLConnection, connect
from .base import DatabaseMetadata, DataSourceBase, DBServerVersion, SQLError, SQLException
with open(Path(__file__).parent / "metadata_sql" / "mysql.sql", encoding="utf-8") as f:
METADATA_SQL = f.read()
version_p = re.compile(r"\b(v)?(\d+)\.(\d+)(?:\.(\d+))?(?:-([a-zA-Z0-9]+))?\b")
class Mysql(DataSourceBase):
def __init__(self, host: str, port: int, username: str, password: str, db_name: str):
self.host = host
self.port = port
self.username = username
self.password = password
self.db_name = db_name
@property
def dialect(self) -> str:
return "mysql"
@property
def string_types(self) -> set[str]:
return {
"CHAR",
"VARCHAR",
"TEXT",
"TINYTEXT",
"MEDIUMTEXT",
"LONGTEXT",
}
@property
def json_array_agg_func(self) -> str:
return "JSON_ARRAYAGG"
async def acreate_connection(self) -> MySQLConnection:
conn = await connect(
host=self.host, port=self.port, user=self.username, password=self.password, database=self.db_name
)
return cast(MySQLConnection, conn)
async def aget_server_version(self) -> DBServerVersion:
sql = "SELECT VERSION()"
_, rows, _ = await self.aexecute_raw_sql(sql)
version = cast(str, rows[0][0])
match = version_p.match(version)
if match:
has_v, major, minor, patch, suffix = match.groups()
server_version = DBServerVersion(major=int(major), minor=int(minor))
if patch:
server_version.patch = int(patch)
return server_version
else:
raise ValueError(f"Invalid version: {version}")
async def aexecute_raw_sql(self, sql: str) -> tuple[list[str], list[tuple[Any, ...]], SQLException | None]:
if not sql:
return [], [], SQLException(error_type=SQLError.EmptySQL, msg="SQL statement is empty")
try:
conn = await self.acreate_connection()
except Exception as e:
return [], [], SQLException(error_type=SQLError.DBError, msg=e.args[1])
try:
cursor = await conn.cursor()
await cursor.execute(sql)
res = await cursor.fetchall()
if cursor.description:
cols = [i[0] for i in cursor.description]
else:
cols = []
await cursor.close()
await conn.close()
return cols, res, None
except Exception as e:
return [], [], SQLException(error_type=SQLError.SyntaxError, msg=e.args[1], code=e.args[0])
async def aget_metadata(self) -> DatabaseMetadata:
sql = METADATA_SQL.format(db_name=self.db_name)
cols, rows, err = await self.aexecute_raw_sql(sql)
if err:
raise ConnectionError(err.msg)
if not rows or not rows[0][1]:
return DatabaseMetadata(name=self.db_name)
metadata = DatabaseMetadata.model_validate({
"name": rows[0][0],
"tables": orjson.loads(cast(str, rows[0][1])),
})
return metadata


@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
import re
from pathlib import Path
from typing import Any, cast
import asyncpg
import orjson
from .base import DatabaseMetadata, DataSourceBase, DBServerVersion, SQLError, SQLException
with open(Path(__file__).parent / "metadata_sql" / "postgres.sql", encoding="utf-8") as f:
METADATA_SQL = f.read()
version_pattern = re.compile(
r".*(?:PostgreSQL|EnterpriseDB) "
r"(\d+)\.?(\d+)?(?:\.(\d+))?(?:\.\d+)?(?:devel|beta)?"
)
class Postgres(DataSourceBase):
"""
PostgreSQL does not support cross-database queries.
"""
def __init__(self, host: str, port: int, username: str, password: str, db_name: str, schema: str = "public"):
self.host = host
self.port = port
self.username = username
self.password = password
self.db_name = db_name
self.schema = schema
@property
def dialect(self) -> str:
return "postgres"
@property
def string_types(self) -> set[str]:
return {"CHARACTER VARYING", "VARCHAR", "CHAR", "CHARACTER", "TEXT"}
@property
def json_array_agg_func(self) -> str:
return "JSON_AGG"
async def acreate_connection(self) -> asyncpg.Connection:
conn = await asyncpg.connect(
host=self.host,
port=self.port,
user=self.username,
password=self.password,
database=self.db_name,
)
return conn
async def aget_server_version(self) -> DBServerVersion:
sql = "SELECT VERSION() as v"
_, rows, _ = await self.aexecute_raw_sql(sql)
version_str = cast(str, rows[0][0])
m = version_pattern.match(version_str)
if not m:
raise AssertionError("Could not determine version from string '%s'" % version_str)
version = [int(x) for x in m.group(1, 2, 3) if x is not None]
server_version = DBServerVersion(major=version[0], minor=version[1])
if len(version) > 2:
server_version.patch = version[2]
return server_version
async def aexecute_raw_sql(self, sql: str) -> tuple[list[str], list[tuple[Any, ...]], SQLException | None]:
if not sql:
return [], [], SQLException(error_type=SQLError.EmptySQL, msg="SQL statement is empty")
try:
conn = await self.acreate_connection()
except Exception as e:
return [], [], SQLException(error_type=SQLError.DBError, msg=str(e))
try:
res: list[asyncpg.Record] = await conn.fetch(sql)
if not res:
# A valid query may legitimately return zero rows
return [], [], None
return list(res[0].keys()), [tuple(x.values()) for x in res], None
except Exception as e:
return [], [], SQLException(error_type=SQLError.SyntaxError, msg=str(e))
finally:
await conn.close()
async def aget_metadata(self) -> DatabaseMetadata:
sql = METADATA_SQL.format(schema=self.schema)
cols, rows, err = await self.aexecute_raw_sql(sql)
if err:
raise ConnectionError(err.msg)
if not rows or not rows[0][1]:
return DatabaseMetadata(name=self.db_name)
metadata = DatabaseMetadata.model_validate({
"name": self.db_name,
"tables": orjson.loads(cast(str, rows[0][1])),
})
return metadata


@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
from pathlib import Path
from typing import Any, cast
import aiosqlite
import orjson
from .base import DatabaseMetadata, DataSourceBase, DBServerVersion, SQLError, SQLException
with open(Path(__file__).parent / "metadata_sql" / "sqlite.sql", encoding="utf-8") as f:
METADATA_SQL = f.read()
class Sqlite(DataSourceBase):
def __init__(self, db_path: str | Path):
if isinstance(db_path, str):
db_path = Path(db_path)
self.db_path = db_path
self.db_name = db_path.stem
@property
def dialect(self) -> str:
return "sqlite"
@property
def string_types(self) -> set[str]:
return {"TEXT"}
@property
def json_array_agg_func(self) -> str:
return "JSON_GROUP_ARRAY"
async def conn_test(self) -> bool:
if self.db_path.is_file():
return True
return False
async def aget_server_version(self) -> DBServerVersion:
version = [int(i) for i in aiosqlite.sqlite_version.split(".") if i.isdigit()]
server_version = DBServerVersion(major=version[0], minor=version[1])
if len(version) > 2:
server_version.patch = version[2]
return server_version
async def aget_metadata(self) -> DatabaseMetadata:
_, rows, err = await self.aexecute_raw_sql(METADATA_SQL)
if err:
raise ConnectionError(err.msg)
if not rows or not rows[0][0]:
return DatabaseMetadata(name=self.db_name)
metadata = DatabaseMetadata.model_validate({
"name": self.db_name,
"tables": orjson.loads(cast(str, rows[0][0])),
})
return metadata
async def aexecute_raw_sql(self, sql: str) -> tuple[list[str], list[tuple[Any, ...]], SQLException | None]:
try:
async with aiosqlite.connect(self.db_path) as db:
cursor = await db.execute(sql)
res = await cursor.fetchall()
return [i[0] for i in cursor.description], cast(list[tuple[Any, ...]], res), None
except Exception as e:
return [], [], SQLException(error_type=SQLError.SyntaxError, msg=str(e))


@ -0,0 +1 @@
# -*- coding: utf-8 -*-


@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
from ..data_source.base import DBTable
from ..utils import JOIN_CHAR, calculate_metrics, parse_real_cols
def calculate_retrieval_metrics(
tables: list[DBTable], gold_sql: str, dialect: str = "mysql", case_insensitive: bool = True
) -> tuple[float, float, float]:
gold_cols = parse_real_cols(gold_sql, dialect, case_insensitive)
recall_cols = set()
for table in tables:
for cn in table.columns.keys():
col_name = f"{table.name}{JOIN_CHAR}{cn}"
if case_insensitive:
col_name = col_name.lower()
recall_cols.add(col_name)
tp = len(gold_cols & recall_cols)
fp = len(recall_cols - gold_cols)
fn = len(gold_cols - recall_cols)
return calculate_metrics(tp, fp, fn)


@ -0,0 +1,115 @@
# -*- coding: utf-8 -*-
from typing import Any
from ..utils import calculate_metrics
def calculate_ex(predicted_res: list[tuple[Any, ...]], ground_truth_res: list[tuple[Any, ...]]) -> int:
res = 0
if set(predicted_res) == set(ground_truth_res):
res = 1
return res
def calculate_row_match(
predicted_row: tuple[Any, ...], ground_truth_row: tuple[Any, ...]
) -> tuple[float, float, float]:
"""
Calculate the matching percentage for a single row.
Args:
predicted_row (tuple): The predicted row values.
ground_truth_row (tuple): The actual row values from ground truth.
Returns:
tuple[float, float, float]: the match, prediction-only, and ground-truth-only percentages (0 to 1 scale).
"""
total_columns = len(ground_truth_row)
matches = 0
element_in_pred_only = 0
element_in_truth_only = 0
for pred_val in predicted_row:
if pred_val in ground_truth_row:
matches += 1
else:
element_in_pred_only += 1
for truth_val in ground_truth_row:
if truth_val not in predicted_row:
element_in_truth_only += 1
match_percentage = matches / total_columns
pred_only_percentage = element_in_pred_only / total_columns
truth_only_percentage = element_in_truth_only / total_columns
return match_percentage, pred_only_percentage, truth_only_percentage
def calculate_f1(
predicted: list[tuple[Any, ...]], ground_truth: list[tuple[Any, ...]]
) -> tuple[float, float, float]:
"""
Calculate the F1 score based on sets of predicted results and ground truth results,
where each element (tuple) represents a row from the database with multiple columns.
Args:
predicted (list of tuples): Predicted rows returned by the SQL query.
ground_truth (list of tuples): Expected rows (ground truth).
Returns:
tuple[float, float, float]: precision, recall, and F1 score.
"""
# if both predicted and ground_truth are empty, return 1.0 for f1_score
if not predicted and not ground_truth:
return 1.0, 1.0, 1.0
# Calculate matching scores for each possible pair
match_scores: list[float] = []
pred_only_scores: list[float] = []
truth_only_scores: list[float] = []
for i, gt_row in enumerate(ground_truth):
# rows only in the ground truth results
if i >= len(predicted):
match_scores.append(0)
truth_only_scores.append(1)
continue
pred_row = predicted[i]
match_score, pred_only_score, truth_only_score = calculate_row_match(pred_row, gt_row)
match_scores.append(match_score)
pred_only_scores.append(pred_only_score)
truth_only_scores.append(truth_only_score)
# rows only in the predicted results
for i in range(len(predicted) - len(ground_truth)):
match_scores.append(0)
pred_only_scores.append(1)
truth_only_scores.append(0)
tp = sum(match_scores)
fp = sum(pred_only_scores)
fn = sum(truth_only_scores)
precision, recall, f1_score = calculate_metrics(tp, fp, fn)
return precision, recall, f1_score
def calculate_ves(
predicted_res: list[tuple[Any, ...]],
ground_truth_res: list[tuple[Any, ...]],
pred_cost_time: float,
ground_truth_cost_time: float,
) -> float:
time_ratio = 0
if set(predicted_res) == set(ground_truth_res):
time_ratio = ground_truth_cost_time/pred_cost_time
if time_ratio == 0:
reward = 0
elif time_ratio >= 2:
reward = 1.25
elif 1 <= time_ratio < 2:
reward = 1
elif 0.5 <= time_ratio < 1:
reward = 0.75
elif 0.25 <= time_ratio < 0.5:
reward = 0.5
else:
reward = 0.25
return reward
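
A tiny worked example of these metrics; the import path below is an assumption (the file path is not shown in this hunk):

```python
# Worked example; the module path is assumed and the rows are made up.
from dative.core.evaluation.sql import calculate_ex, calculate_f1, calculate_ves  # path assumed

pred = [(1, "alice"), (2, "bob")]
gold = [(1, "alice"), (2, "bob")]

print(calculate_ex(pred, gold))             # 1: result sets are identical
print(calculate_f1(pred, gold))             # perfect precision/recall/F1
print(calculate_ves(pred, gold, 0.2, 0.1))  # 0.75: gold query ran in half the time
```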


@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
import asyncio
from json import JSONDecodeError
from typing import cast
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessageChunk, HumanMessage
from ..utils import cal_tokens, get_beijing_time, text2json
prompt_prefix = """
You are a helpful data analyst who is great at thinking deeply and reasoning about the user's question and the database schema.
## 1. Database info
{db_info}
## 2. Context Information
- Current Time: {current_time}. It is only used when calculation or conversion is required according to the current system time.
## 3. Constraints
- Generate an optimized SQL query that directly answers the user's question.
- The SQL query must be fully formed, valid, and executable.
- Do NOT include any explanations, markdown formatting, or comments.
- If you want the maximum or minimum value, do not limit it to 1.
- When you need to calculate the proportion or other indicators, please use double precision.
## 4. Evidence
{evidence}
## 5. QUESTION
User's Question: {user_query}
""" # noqa: E501
generation_prompt = """
Respond in the following JSON format:
- If the user query is not related to the database, answer with empty string.
{{
"answer": ""
}}
- If you can answer the questions based on the database schema and don't need to generate SQL, generate the answers directly.
{{
"answer": "answer based on database schema"
}}
- If you need to answer the question by querying the database, please generate SQL, select only the necessary fields needed to answer the question, without any missing or extra information.
- Prefer using aggregate functions (such as COUNT, SUM, etc.) in SQL query and avoid returning redundant data for post-processing.
{{
"sql": "Generated SQL query here"
}}
""" # noqa: E501
correction_prompt = """
There is a SQL, but an error was reported after execution. Analyze why the given SQL query does not produce the correct results, identify the issues, and provide a corrected SQL query that properly answers the user's request.
- Current SQL Query: {current_sql}
- Error: {error}
**What You Need to Do:**
1. **Analyze:** Explain why the current SQL query fails to produce the correct results.
2. **Provide a Corrected SQL Query:** Write a revised query that returns the correct results.
Respond in the following JSON format:
{{
"sql": "The corrected SQL query should be placed here."
}}
""" # noqa: E501
async def arun(
query: str,
llm: BaseChatModel,
db_info: str,
evidence: str = "",
error_sql: str | None = None,
error_msg: str | None = None,
) -> tuple[str | None, str, int, int, str | None]:
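    """Generate a SQL query (or a direct answer) for the user's question, optionally correcting a failed SQL.
    Returns (answer, sql, input_tokens, output_tokens, error): `answer` is set when no SQL is needed,
    `sql` is the generated query, and `error` is set when the LLM output cannot be parsed as JSON.
    """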
answer = None
sql = ""
error = None
data = {
"db_info": db_info,
"user_query": query,
"current_time": await asyncio.to_thread(get_beijing_time),
"evidence": evidence or "",
}
input_tokens, output_tokens = 0, 0
if not error_sql:
prompt = prompt_prefix + generation_prompt
else:
assert error_msg, "error_msg is required when you need to correct sql"
data["current_sql"] = error_sql
data["error"] = error_msg
prompt = prompt_prefix + correction_prompt
try:
content = ""
async for chunk in llm.astream([HumanMessage(prompt.format(**data))]):
msg = cast(AIMessageChunk, chunk)
content += cast(str, msg.content)
input_tokens, output_tokens = await asyncio.to_thread(cal_tokens, msg)
except Exception as e:
        raise ValueError(f"Failed to call the LLM: {e}")
if content.startswith("SELECT"):
sql = content
elif content.find("```sql") != -1:
sql = content.split("```sql")[1].split("```")[0]
else:
try:
result = await asyncio.to_thread(text2json, content)
if "answer" in result:
answer = result.get("answer") or ""
else:
sql = result.get("sql") or ""
except JSONDecodeError:
error = "Incorrect json format"
return answer, sql, input_tokens, output_tokens, error

View File

@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
from .check import SQLCheck
from .optimization import SQLOptimization
__all__ = ["SQLCheck", "SQLOptimization"]

View File

@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
from sqlglot import exp, parse_one
class SQLCheck:
def __init__(self, dialect: str):
self.dialect = dialect
def is_query(self, sql_or_exp: str | exp.Expression) -> bool:
"""判断是否是查询语句"""
if isinstance(sql_or_exp, exp.Expression):
expression = sql_or_exp
else:
expression = parse_one(sql_or_exp, dialect=self.dialect)
return isinstance(expression, exp.Query)
def syntax_valid(self, sql: str) -> exp.Expression:
"""基本语法验证"""
return parse_one(sql, dialect=self.dialect)
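# A minimal usage sketch (mirroring the unit tests):
#   checker = SQLCheck(dialect="mysql")
#   checker.is_query("SELECT * FROM users")                        # True
#   checker.is_query("INSERT INTO users (name) VALUES ('Alice')")  # False
#   checker.syntax_valid("SELECT FROM WHERE")                      # raises a parse error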

View File

@ -0,0 +1,156 @@
# -*- coding: utf-8 -*-
from typing import Any, cast
from sqlglot import ParseError, exp, parse_one
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.optimizer import RULES, optimize
class SQLOptimization:
def __init__(self, dialect: str, db_major_version: int):
self.dialect = dialect
self.db_major_version = db_major_version
@staticmethod
def cte_to_subquery(expression: exp.Expression) -> exp.Expression:
        # Collect all CTEs
cte_names = {}
for cte in expression.find_all(exp.CTE):
alias = cte.alias
query = cte.this
cte_names[alias] = query
            # Remove the original CTE node
if cte.parent:
cte.parent.pop()
if not cte_names:
return expression
        # Replace every reference to a CTE with a subquery
def replace_cte_with_subquery(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Table) and node.name in cte_names:
subquery = cte_names[node.name]
return exp.Subquery(this=subquery.copy(), alias=node.alias or node.name)
return node
return expression.transform(replace_cte_with_subquery)
@staticmethod
def optimize_in_limit_subquery(expression: exp.Expression) -> exp.Expression:
"""
Avoid 'LIMIT & IN/ALL/ANY/SOME subquery'
"""
for in_expr in expression.find_all(exp.In):
subquery = in_expr.args.get("query")
if not subquery:
continue
if subquery.this.find(exp.Limit):
t = subquery.this.args.get("from").this
if t.args.get("alias"):
alias = t.args.get("alias").this.this
else:
alias = exp.TableAlias(this=exp.Identifier(this="t"))
derived_table = exp.Subquery(this=subquery.this.copy(), alias=alias)
                # Build a new SELECT ... FROM (subquery) AS t
new_subquery_select = exp.select(*subquery.this.expressions).from_(derived_table)
                # Replace the IN subquery with the rewritten one
in_expr.set("query", exp.Subquery(this=new_subquery_select))
return expression
@staticmethod
def fix_missing_group_by_when_agg_func(expression: exp.Expression) -> exp.Expression:
"""
case2SELECT a, COUNT(b) FROM x --> SELECT a, b FROM x GROUP BY a
case3SELECT a FROM x ORDER BY MAX(b) --> SELECT a FROM x GROUP BY a ORDER BY MAX(b)
"""
for select_expr in expression.find_all(exp.Select):
select_agg = False
not_agg_query_cols = dict()
group_cols = dict()
order_by_agg = False
for col in select_expr.expressions:
if isinstance(col, exp.Column):
not_agg_query_cols[col.this.this] = col
elif isinstance(col, exp.AggFunc):
select_agg = True
elif isinstance(col, exp.Alias):
if isinstance(col.this, exp.Column):
not_agg_query_cols[col.this.this.this] = col.this
elif isinstance(col.this, exp.AggFunc):
select_agg = True
            if select_expr.args.get("group"):
                for col in select_expr.args["group"].expressions:
                    group_cols[col.this.this] = col
            if select_expr.args.get("order"):
                for order_col in select_expr.args["order"].expressions:
if isinstance(order_col.this, exp.AggFunc):
order_by_agg = True
if group_cols or (select_agg and not_agg_query_cols) or order_by_agg:
for col in not_agg_query_cols:
if col not in group_cols:
group_cols[col] = not_agg_query_cols[col]
if group_cols:
select_expr.set("group", exp.Group(expressions=group_cols.values()))
return expression
@staticmethod
def set_limit(expression: exp.Expression, result_num_limit: int):
if expression.args.get("limit"):
limit_exp = cast(exp.Limit, expression.args.get("limit"))
limit = min(result_num_limit, int(limit_exp.expression.this))
else:
limit = result_num_limit
expression.set("limit", exp.Limit(expression=exp.Literal.number(limit)))
def arun(
self,
sql_or_exp: str | exp.Expression,
schema_type: dict[str, dict[str, Any]] | None = None,
result_num_limit: int | None = None,
) -> str:
"""
Args:
sql_or_exp:
schema_type: db schema type, a mapping in one of the following forms:
1. {table: {col: type}}
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
result_num_limit: int
Returns: str, optimized sql
"""
if isinstance(sql_or_exp, exp.Expression):
expression = sql_or_exp
else:
try:
expression = parse_one(sql_or_exp, dialect=self.dialect)
except ParseError:
expression = parse_one(sql_or_exp)
if result_num_limit and result_num_limit > 0:
self.set_limit(expression, result_num_limit)
rules = list(RULES)
if self.dialect == "mysql" and self.db_major_version < 8:
            # CTEs are only supported from MySQL 8.0; older versions must fall back to subqueries
rules.remove(eliminate_subqueries)
rules.append(self.cte_to_subquery)
        # Rewrite IN (... LIMIT ...) subqueries
rules.append(self.optimize_in_limit_subquery)
        # For aggregate queries, add columns missing from GROUP BY
rules.append(self.fix_missing_group_by_when_agg_func)
expression = optimize(
expression,
schema=schema_type,
dialect=self.dialect,
rules=rules, # type: ignore
identify=False,
)
return expression.sql(self.dialect)
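# A minimal usage sketch (illustrative schema and SQL; see the unit tests for more cases):
#   optimizer = SQLOptimization(dialect="mysql", db_major_version=8)
#   schema = {"users": {"id": "INT", "name": "VARCHAR"}}
#   optimizer.arun("SELECT name FROM users", schema_type=schema, result_num_limit=100)
#   # -> an optimized SQL string with "LIMIT 100" applied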

View File

@ -0,0 +1,79 @@
# -*- coding: utf-8 -*-
import asyncio
from json import JSONDecodeError
from typing import Any, cast
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from ..utils import cal_tokens, text2json
prompt = """
# Database info
{db_info}
# evidence:
{evidence}
# Original user query:
{query}
# Generated SQL:
{generated_sql}
# SQL query results:
{sql_result}
# Constraints:
- 1. Evaluate the relevance and quality of the SQL query results to the original user query.
- 2. If the SQL results are empty or are relevant to the query, answer the user's question directly based on the information above; do not output redundant phrases such as "according to" or "based on".
Respond in the following JSON format:
{{
"final_answer": "Return a full natural language answer"
}}
- 3. If the SQL results are not empty but are not relevant, explain why the SQL query is not relevant.
Respond in the following JSON format:
{{
"explanation": "explanation of why the SQL query is not relevant"
}}
""" # noqa:E501
async def arun(
query: str,
llm: BaseChatModel,
db_info: str,
sql: str,
res_rows: list[tuple[Any, ...]],
evidence: str | None = None,
) -> tuple[str, int, int, str | None]:
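    """Evaluate the SQL results against the user query and produce a natural-language answer.
    Returns (answer, input_tokens, output_tokens, error): `error` carries the model's explanation
    when it judges the results irrelevant, or a JSON-format failure message.
    """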
answer = ""
error = None
input_tokens, output_tokens = 0, 0
human_msg = prompt.format(
db_info=db_info,
query=query,
generated_sql=sql,
sql_result=res_rows,
evidence=evidence or "",
)
try:
content = ""
async for chunk in llm.astream([HumanMessage(human_msg)]):
msg = cast(AIMessage, chunk)
content += cast(str, msg.content)
input_tokens, output_tokens = await asyncio.to_thread(cal_tokens, msg)
except Exception as e:
        raise ValueError(f"Failed to call the LLM: {e}")
try:
result = await asyncio.to_thread(text2json, cast(str, content))
if "final_answer" in result:
answer = result["final_answer"] or ""
else:
error = result.get("explanation", "")
except JSONDecodeError:
error = "Incorrect json format"
return answer, input_tokens, output_tokens, error

View File

@ -0,0 +1,247 @@
# -*- coding: utf-8 -*-
import asyncio
import datetime
import functools
import inspect
import os
import re
import uuid
from concurrent.futures import ThreadPoolExecutor
from decimal import Decimal
from typing import Any, Callable
import json_repair as json
from dateutil.parser import parse as date_parse
from langchain_core.messages import AIMessage
from sqlglot import exp, parse_one
from sqlglot.optimizer.qualify import qualify
from sqlglot.optimizer.scope import Scope, traverse_scope
JOIN_CHAR = "."
email_re = re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
def text2json(text: str) -> dict[Any, Any]:
res = json.loads(text)
if isinstance(res, dict):
return res
return dict()
def get_beijing_time(fmt: str = "%Y-%m-%d %H:%M:%S") -> str:
utc_now = datetime.datetime.now(datetime.UTC)
beijing_now = utc_now + datetime.timedelta(hours=8)
formatted_time = beijing_now.strftime(fmt)
return formatted_time
def cal_tokens(msg: AIMessage) -> tuple[int, int]:
if msg.usage_metadata:
return msg.usage_metadata["input_tokens"], msg.usage_metadata["output_tokens"]
elif msg.response_metadata.get("token_usage", {}).get("input_tokens"):
token_usage = msg.response_metadata.get("token_usage", {})
input_tokens = token_usage.get("input_tokens", 0)
output_tokens = token_usage.get("output_tokens", 0)
return input_tokens, output_tokens
else:
return 0, 0
def convert_value2str(value: Any) -> str:
if isinstance(value, str):
return value
elif isinstance(value, Decimal):
return str(float(value))
elif value is None:
return ""
else:
return str(value)
def is_valid_uuid(s: str) -> bool:
try:
uuid.UUID(s)
return True
except ValueError:
return False
def is_number(value: Any) -> bool:
try:
float(value)
return True
except (ValueError, TypeError):
return False
def is_date(value: Any) -> bool:
try:
date_parse(value)
return True
except ValueError:
return False
def is_email(value: Any) -> bool:
if email_re.match(value):
return True
return False
def truncate_text(content: Any, *, max_length: int, suffix: str = "...") -> Any:
if not isinstance(content, str) or max_length <= 0:
return content
if len(content) <= max_length:
return content
if max_length <= len(suffix):
return content[:max_length]
    # Make sure the truncated text does not exceed the maximum length and is not cut in the middle of a word.
return content[: max_length - len(suffix)].rsplit(" ", 1)[0] + suffix
def truncate_text_by_byte(content: Any, *, max_length: int, suffix: str = "...", encoding: str = "utf-8") -> Any:
if not isinstance(content, str) or max_length <= 0:
return content
encoded = content.encode(encoding)
if len(encoded) <= max_length:
return content
suffix_bytes = suffix.encode(encoding)
available_bytes = max_length - len(suffix_bytes)
if available_bytes <= 0:
for i in range(max_length, 0, -1):
try:
return encoded[:i].decode(encoding)
except UnicodeDecodeError:
continue
return ""
    # Try to find a truncation point that still decodes cleanly
for i in range(available_bytes, 0, -1):
try:
truncated_part = encoded[:i].decode(encoding)
return truncated_part + suffix
except UnicodeDecodeError:
continue
return suffix_bytes[:max_length].decode(encoding, errors="ignore")
async def async_parallel_exec(func: Callable[[Any], Any], data: list[Any], concurrency: int | None = None) -> list[Any]:
if not inspect.iscoroutinefunction(func):
async_func = functools.partial(asyncio.to_thread, func) # type: ignore
else:
async_func = func # type: ignore
if len(data) == 0:
return []
if concurrency is None:
concurrency = len(data)
semaphore = asyncio.Semaphore(concurrency)
async def worker(*args, **kwargs) -> Any:
async with semaphore:
return await async_func(*args, **kwargs) # type: ignore
t_list = []
for i in range(len(data)):
if isinstance(data[i], (list, tuple)):
t_list.append(worker(*data[i]))
elif isinstance(data[i], dict):
t_list.append(worker(**data[i]))
else:
t_list.append(worker(data[i]))
return await asyncio.gather(*t_list)
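# A minimal usage sketch (hypothetical coroutine fetch(url) and function add(a, b)):
#   await async_parallel_exec(fetch, ["u1", "u2"], concurrency=2)  # plain items -> func(item)
#   await async_parallel_exec(add, [(1, 2), (3, 4)])               # tuples -> func(*item)
#   await async_parallel_exec(add, [{"a": 1, "b": 2}])             # dicts -> func(**item)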
def parallel_exec(
func: Callable[[Any], Any], data: list[dict[Any, Any] | tuple[Any] | list[Any]], concurrency: int | None = None
) -> list[Any]:
if len(data) == 0:
return []
t_list = []
if concurrency is None:
concurrency = min(len(data), 32, (os.cpu_count() or 1) + 4)
with ThreadPoolExecutor(max_workers=concurrency) as pool:
for i in range(len(data)):
if isinstance(data[i], (list, tuple)):
t_list.append(pool.submit(func, *data[i]))
elif isinstance(data[i], dict):
t_list.append(pool.submit(func, **data[i])) # type: ignore
else:
t_list.append(pool.submit(func, data[i]))
return [t.result() for t in t_list]
def exec_async_func(func: Callable[[Any], Any], *args, **kwargs) -> Any:
if inspect.iscoroutinefunction(func):
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
return loop.run_until_complete(func(*args, **kwargs))
else:
return func(*args, **kwargs)
def parse_col_name_from_scope(
scope: Scope, real_cols: set[str], cte_names: set[str], case_insensitive: bool = True
) -> None:
if case_insensitive:
scope.sources = {k.lower(): v for k, v in scope.sources.items()}
for col in scope.columns:
if col.table:
tn = col.table
if case_insensitive:
tn = tn.lower()
if tn in scope.sources:
source = scope.sources[tn]
if isinstance(source, exp.Table) and tn not in cte_names:
src_tn = source.name
col_name = f"{src_tn}{JOIN_CHAR}{col.name}"
if case_insensitive:
col_name = col_name.lower()
real_cols.add(col_name)
elif not col.table and scope.tables:
tn = scope.tables[0].name
if tn not in cte_names:
col_name = f"{tn}{JOIN_CHAR}{col.name}"
if case_insensitive:
col_name = col_name.lower()
real_cols.add(col_name)
def parse_real_cols(sql: str, dialect: str = "mysql", case_insensitive: bool = True) -> set[str]:
try:
parsed = parse_one(sql, dialect=dialect)
parsed = qualify(parsed, dialect=dialect)
cte_names = set()
for cte in parsed.find_all(exp.CTE):
cte_names.add(cte.alias)
scopes = traverse_scope(parsed)
real_cols: set[str] = set()
for scope in scopes:
parse_col_name_from_scope(scope, real_cols, cte_names, case_insensitive)
return real_cols
except Exception as e:
print(f"Error when parsing, error: {e}")
return set()
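# e.g. parse_real_cols("SELECT t1.a, t2.b FROM table1 t1 JOIN table2 t2 ON t1.id = t2.id")
# returns {"table1.a", "table1.id", "table2.b", "table2.id"}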
def calculate_metrics(tp: float, fp: float, fn: float) -> tuple[float, float, float]:
precision = tp / (tp + fp) if tp + fp > 0 else 0
recall = tp / (tp + fn) if tp + fn > 0 else 0
f1_score = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
return precision, recall, f1_score

View File

@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from dative.api import v1
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(v1.router, prefix="/api/v1")
@app.get("/healthz")
async def healthz() -> str:
return "ok"
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=3000)
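# Once the server is running (default port 3000), the health check should return "ok":
#   curl http://localhost:3000/healthz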

View File

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
import os
import sys
import time
from typing import Callable
import pytest
# Add the src directory to sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
# Note: pytest.fixture does not mix well with unittest.TestCase classes
if os.name != "nt":
@pytest.fixture(scope="session", autouse=True)
def setup_and_teardown():
# Setup logic
        yield  # run the tests
# Teardown logic
def wait_until(condition_function: Callable, timeout=10):
start_time = time.time()
while not condition_function():
if timeout is not None and time.time() - start_time > timeout:
raise TimeoutError("Condition not met within the specified timeout")
        time.sleep(0.01)  # brief sleep to reduce CPU usage
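# A minimal usage sketch inside a test (hypothetical `service` object):
#   wait_until(lambda: service.is_ready(), timeout=5)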

View File

@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-
from unittest.mock import Mock
import pytest
from src.dative.core.evaluation.schema_retrieval import calculate_f1_score
def test_calculate_f1_score_perfect_match():
"""测试检索结果与gold SQL完全匹配的情况"""
# 创建模拟的DBTable对象
table1 = Mock()
table1.name = "table1"
table1.columns = {"a": None, "b": None} # columns.keys()返回['a', 'b']
table2 = Mock()
table2.name = "table2"
table2.columns = {"c": None}
retrieval_res = [table1, table2]
gold_sql = "SELECT table1.a, table1.b, table2.c FROM table1 JOIN table2"
precision, recall, f1 = calculate_f1_score(retrieval_res, gold_sql)
    # With a perfect match, precision, recall and f1 should all be 1.0
assert precision == 1.0
assert recall == 1.0
assert f1 == 1.0
def test_calculate_f1_score_partial_match():
"""测试检索结果与gold SQL部分匹配的情况"""
# 创建模拟的DBTable对象
table1 = Mock()
table1.name = "table1"
table1.columns = {"a": None, "b": None}
retrieval_res = [table1] # 只检索到table1的列
gold_sql = "SELECT table1.a, table1.b, table2.c FROM table1 JOIN table2" # gold SQL需要table1和table2的列
precision, recall, f1 = calculate_f1_score(retrieval_res, gold_sql)
# 精确度应该是1.0 (检索到的都正确)召回率应该是2/3 (只找到2个正确的共3个)
assert precision == 1.0
assert recall == pytest.approx(2 / 3)
assert f1 == pytest.approx(2 * (1.0 * 2 / 3) / (1.0 + 2 / 3))
def test_calculate_f1_score_no_match():
"""测试检索结果与gold SQL完全不匹配的情况"""
# 创建模拟的DBTable对象
table1 = Mock()
table1.name = "table3" # 与gold SQL中的表名不同
table1.columns = {"d": None, "e": None}
retrieval_res = [table1]
gold_sql = "SELECT table1.a, table1.b, table2.c FROM table1 JOIN table2"
precision, recall, f1 = calculate_f1_score(retrieval_res, gold_sql)
    # With no match, precision, recall and f1 should all be 0.0
assert precision == 0.0
assert recall == 0.0
assert f1 == 0.0
def test_calculate_f1_score_case_insensitive():
"""测试大小写不敏感的情况"""
table1 = Mock()
table1.name = "Table1" # 大小写与gold SQL不同
table1.columns = {"A": None, "B": None}
retrieval_res = [table1]
gold_sql = "SELECT table1.a, table1.b FROM table1"
    # Case-insensitive by default
    precision, recall, f1 = calculate_f1_score(retrieval_res, gold_sql)
    # Should match because matching is case-insensitive
assert precision == 1.0
assert recall == 1.0
assert f1 == 1.0
def test_calculate_f1_score_case_sensitive():
"""测试大小写敏感的情况"""
table1 = Mock()
table1.name = "Table1" # 大小写与gold SQL不同
table1.columns = {"A": None, "B": None}
retrieval_res = [table1]
gold_sql = "SELECT table1.a, table1.b FROM table1"
    # Enable case-sensitive matching
    precision, recall, f1 = calculate_f1_score(retrieval_res, gold_sql, case_insensitive=False)
    # Should not match because matching is case-sensitive
assert precision == 0.0
assert recall == 0.0
assert f1 == 0.0
def test_calculate_f1_score_empty_inputs():
"""测试空输入的情况"""
# 空的检索结果和简单的SQL
retrieval_res = []
gold_sql = "SELECT a FROM table1"
precision, recall, f1 = calculate_f1_score(retrieval_res, gold_sql)
    # With empty retrieval results, precision, recall and f1 should all be 0
assert precision == 0.0
assert recall == 0.0
assert f1 == 0.0
    # Empty SQL
table1 = Mock()
table1.name = "table1"
table1.columns = {"a": None}
retrieval_res = [table1]
gold_sql = ""
precision, recall, f1 = calculate_f1_score(retrieval_res, gold_sql)
    # A SQL parse failure should yield zero scores
assert precision == 0.0
assert recall == 0.0
assert f1 == 0.0

View File

@ -0,0 +1,142 @@
# -*- coding: utf-8 -*-
import pytest
from src.dative.core.evaluation.sql_generation import calculate_ex, calculate_f1_score, calculate_row_match
class TestCalculateEx:
"""测试 calculate_ex 函数"""
def test_exact_match(self):
"""测试完全匹配的情况"""
predicted = [(1, "A"), (2, "B")]
ground_truth = [(1, "A"), (2, "B")]
assert calculate_ex(predicted, ground_truth) == 1
def test_exact_match_different_order(self):
"""测试顺序不同但内容相同的匹配"""
predicted = [(2, "B"), (1, "A")]
ground_truth = [(1, "A"), (2, "B")]
assert calculate_ex(predicted, ground_truth) == 1
def test_no_match(self):
"""测试完全不匹配的情况"""
predicted = [(1, "A"), (2, "B")]
ground_truth = [(3, "C"), (4, "D")]
assert calculate_ex(predicted, ground_truth) == 0
def test_partial_match(self):
"""测试部分匹配的情况"""
predicted = [(1, "A"), (2, "B")]
ground_truth = [(1, "A"), (3, "C")]
assert calculate_ex(predicted, ground_truth) == 0
def test_empty_inputs(self):
"""测试空输入的情况"""
assert calculate_ex([], []) == 1
assert calculate_ex([(1, "A")], []) == 0
assert calculate_ex([], [(1, "A")]) == 0
class TestCalculateRowMatch:
"""测试 calculate_row_match 函数"""
def test_perfect_match(self):
"""测试行完全匹配的情况"""
predicted = (1, "A", 3.14)
ground_truth = (1, "A", 3.14)
match_pct, pred_only_pct, truth_only_pct = calculate_row_match(predicted, ground_truth)
assert match_pct == 1.0
assert pred_only_pct == 0.0
assert truth_only_pct == 0.0
def test_partial_match(self):
"""测试行部分匹配的情况"""
predicted = (1, "B", 3.14)
ground_truth = (1, "A", 3.14)
match_pct, pred_only_pct, truth_only_pct = calculate_row_match(predicted, ground_truth)
# 2个值匹配(1和3.14)共3列
assert match_pct == pytest.approx(2 / 3)
assert pred_only_pct == pytest.approx(1 / 3)
assert truth_only_pct == pytest.approx(1 / 3)
def test_no_match(self):
"""测试行完全不匹配的情况"""
predicted = (2, "B", 2.71)
ground_truth = (1, "A", 3.14)
match_pct, pred_only_pct, truth_only_pct = calculate_row_match(predicted, ground_truth)
assert match_pct == 0.0
assert pred_only_pct == 1.0
assert truth_only_pct == 1.0
def test_duplicate_values(self):
"""测试包含重复值的情况"""
predicted = (1, 1, "A")
ground_truth = (1, "A", "A")
match_pct, pred_only_pct, truth_only_pct = calculate_row_match(predicted, ground_truth)
# 1和A都存在于ground_truth中
assert match_pct == 1.0
assert pred_only_pct == 0.0
assert truth_only_pct == 0.0
class TestCalculateF1Score:
"""测试 calculate_f1_score 函数"""
def test_perfect_match(self):
"""测试完全匹配的情况"""
predicted = [(1, "A"), (2, "B")]
ground_truth = [(1, "A"), (2, "B")]
f1_score = calculate_f1_score(predicted, ground_truth)
assert f1_score == 1.0
def test_perfect_match_different_order(self):
"""测试顺序不同但内容相同的匹配"""
predicted = [(2, "B"), (1, "A")]
ground_truth = [(1, "A"), (2, "B")]
f1_score = calculate_f1_score(predicted, ground_truth)
assert f1_score == 1.0
def test_empty_results(self):
"""测试都为空的情况"""
f1_score = calculate_f1_score([], [])
assert f1_score == 1.0
def test_one_empty_result(self):
"""测试一个为空的情况"""
predicted = [(1, "A")]
ground_truth = []
f1_score = calculate_f1_score(predicted, ground_truth)
assert f1_score == 0.0
predicted = []
ground_truth = [(1, "A")]
f1_score = calculate_f1_score(predicted, ground_truth)
assert f1_score == 0.0
def test_partial_match(self):
"""测试部分匹配的情况"""
predicted = [(1, "A", 10), (3, "C", 30)]
ground_truth = [(1, "A", 10), (2, "B", 20)]
# 第一行完全匹配(1.0),第二行部分匹配(部分值匹配)
f1_score = calculate_f1_score(predicted, ground_truth)
# 需要确保返回的是一个合理的f1分数(0-1之间)
assert 0.0 <= f1_score <= 1.0
def test_duplicate_rows(self):
"""测试重复行的情况"""
predicted = [(1, "A"), (1, "A"), (2, "B")]
ground_truth = [(1, "A"), (2, "B"), (2, "B")]
f1_score = calculate_f1_score(predicted, ground_truth)
# 重复项应该被去除,结果应该与去重后相同
assert 0.0 <= f1_score <= 1.0
def test_different_row_lengths(self):
"""测试行长度不同的情况"""
predicted = [(1, "A", 10, "extra")]
ground_truth = [(1, "A", 10)]
# 这种情况可能产生意外结果,但函数应该能处理
f1_score = calculate_f1_score(predicted, ground_truth)
assert 0.0 <= f1_score <= 1.0

View File

@ -0,0 +1,49 @@
import pytest
from sqlglot import exp
from dative.core.sql_inspection.check import SQLCheck
class TestSQLCheck:
@pytest.fixture
def sql_check(self):
return SQLCheck(dialect="mysql")
def test_is_query_with_select_statement(self, sql_check):
"""测试SELECT语句是否被识别为查询"""
sql = "SELECT * FROM users"
assert sql_check.is_query(sql) is True
def test_is_query_with_insert_statement(self, sql_check):
"""测试INSERT语句是否不被识别为查询"""
sql = "INSERT INTO users (name) VALUES ('Alice')"
assert sql_check.is_query(sql) is False
def test_is_query_with_exp_expression(self, sql_check):
"""测试直接传入Expression对象的情况"""
expression = exp.select("*").from_("users")
assert sql_check.is_query(expression) is True
def test_syntax_valid_with_correct_sql(self, sql_check):
"""测试语法正确的SQL"""
sql = "SELECT id, name FROM users WHERE age > 18"
result = sql_check.syntax_valid(sql)
assert isinstance(result, exp.Expression)
def test_syntax_valid_with_incorrect_sql(self, sql_check):
"""测试语法错误的SQL应抛出异常"""
sql = "SELECT FROM WHERE"
with pytest.raises(Exception):
sql_check.syntax_valid(sql)
def test_mysql_dialect(self):
"""测试MySQL方言"""
sql_check = SQLCheck(dialect="mysql")
sql = "SELECT * FROM users LIMIT 10"
assert sql_check.is_query(sql) is True
def test_postgres_dialect(self):
"""测试PostgreSQL方言"""
sql_check = SQLCheck(dialect="postgres")
sql = "SELECT * FROM users LIMIT 10"
assert sql_check.is_query(sql) is True

View File

@ -0,0 +1,142 @@
# -*- coding: utf-8 -*-
import pytest
from sqlglot import exp, parse_one
from dative.core.sql_inspection.optimization import SQLOptimization
class TestSQLOptimization:
"""SQLOptimization 类的单元测试"""
@pytest.fixture
def mysql_optimizer(self):
"""MySQL 5.7 版本的优化器"""
return SQLOptimization(dialect="mysql", db_major_version=5)
@pytest.fixture
def mysql8_optimizer(self):
"""MySQL 8.0 版本的优化器"""
return SQLOptimization(dialect="mysql", db_major_version=8)
@pytest.fixture
def postgres_optimizer(self):
"""PostgreSQL 优化器"""
return SQLOptimization(dialect="postgres", db_major_version=13)
def test_cte_to_subquery(self):
"""测试 CTE 转换为子查询的功能"""
sql = """
WITH cte1 AS (SELECT id, name FROM users WHERE age > 18)
SELECT * \
FROM cte1 \
WHERE name LIKE 'A%' \
"""
expression = parse_one(sql)
result = SQLOptimization.cte_to_subquery(expression)
        # Verify the CTE has been replaced by a subquery
assert "WITH" not in result.sql()
assert "users" in result.sql()
assert "age > 18" in result.sql()
def test_optimize_in_limit_subquery(self):
"""测试优化带 LIMIT 的 IN 子查询"""
sql = """
SELECT * \
FROM orders
WHERE user_id IN (SELECT id FROM users LIMIT 10) \
"""
expression = parse_one(sql)
result = SQLOptimization.optimize_in_limit_subquery(expression)
        # Verify the LIMIT subquery has been rewritten
        subquery = result.find(exp.In).args["query"]
        assert subquery is not None
        # Make sure the new subquery structure is correct
assert isinstance(subquery, exp.Subquery)
def test_fix_missing_group_by_when_agg_func_case1(self):
"""测试修复 GROUP BY 缺失字段 - 情况1"""
sql = "SELECT a, b FROM x GROUP BY a"
expression = parse_one(sql)
result = SQLOptimization.fix_missing_group_by_when_agg_func(expression)
        # The missing GROUP BY column b should be added
group_clause = result.args.get("group")
assert group_clause is not None
group_expressions = group_clause.expressions
assert len(group_expressions) == 2
def test_fix_missing_group_by_when_agg_func_case2(self):
"""测试修复 GROUP BY 缺失字段 - 情况2"""
sql = "SELECT a, COUNT(b) FROM x"
expression = parse_one(sql)
result = SQLOptimization.fix_missing_group_by_when_agg_func(expression)
        # GROUP BY a should be added automatically
group_clause = result.args.get("group")
assert group_clause is not None
def test_fix_missing_group_by_when_agg_func_case3(self):
"""测试修复 GROUP BY 缺失字段 - 情况3"""
sql = "SELECT a FROM x ORDER BY MAX(b)"
expression = parse_one(sql)
result = SQLOptimization.fix_missing_group_by_when_agg_func(expression)
        # GROUP BY a should be added
group_clause = result.args.get("group")
assert group_clause is not None
@pytest.mark.asyncio
async def test_arun_with_string_sql(self, mysql_optimizer):
"""测试 arun 方法处理字符串 SQL"""
sql = "SELECT id, name FROM users"
result = mysql_optimizer.arun(sql, result_num_limit=100)
assert "LIMIT 100" in result
@pytest.mark.asyncio
async def test_arun_with_expression(self, mysql_optimizer):
"""测试 arun 方法处理表达式对象"""
expression = parse_one("SELECT id, name FROM users")
result = mysql_optimizer.arun(expression, result_num_limit=50)
assert "LIMIT 50" in result
@pytest.mark.asyncio
async def test_mysql_cte_handling_old_version(self, mysql_optimizer):
"""测试 MySQL 低版本 CTE 处理(应转换为子查询)"""
sql = """
WITH cte AS (SELECT id FROM users)
SELECT *
FROM cte \
"""
result = mysql_optimizer.arun(sql)
        # On MySQL 5.x, CTEs should be converted to subqueries
assert "WITH" not in result
@pytest.mark.asyncio
async def test_mysql_cte_handling_new_version(self, mysql8_optimizer):
"""测试 MySQL 8.0+ CTE 处理(应保留 CTE"""
sql = """
WITH cte AS (SELECT id FROM users)
SELECT *
FROM cte
join b
on cte.id = b.id \
"""
result = mysql8_optimizer.arun(sql)
        # On MySQL 8.0+, the CTE should be kept
assert "WITH" in result
@pytest.mark.asyncio
async def test_optimization_with_schema(self, postgres_optimizer):
"""测试带模式信息的优化"""
sql = "SELECT u.name, COUNT(o.id) FROM users u JOIN orders o ON u.id = o.user_id"
schema = {"users": {"id": "INT", "name": "VARCHAR"}, "orders": {"id": "INT", "user_id": "INT"}}
result = postgres_optimizer.arun(sql, schema_type=schema)
assert result is not None
assert isinstance(result, str)

View File

@ -0,0 +1,271 @@
# -*- coding: utf-8 -*-
import asyncio
import datetime
import uuid
from decimal import Decimal
from unittest.mock import MagicMock
import pytest
from dative.core.utils import (
async_parallel_exec,
cal_tokens,
    calculate_metrics,
    convert_value2str,
exec_async_func,
get_beijing_time,
is_date,
is_email,
is_number,
is_valid_uuid,
parallel_exec,
parse_real_cols,
text2json,
truncate_text,
truncate_text_by_byte,
)
def test_text2json_valid_json():
text = '{"name": "Alice", "age": 30}'
expected = {"name": "Alice", "age": 30}
assert text2json(text) == expected
def test_text2json_invalid_json():
text = "invalid json"
expected = {}
assert text2json(text) == expected
def test_text2json_non_dict_result():
text = '["not", "a", "dict"]'
expected = {}
assert text2json(text) == expected
def test_case_001_usage_metadata_exists():
msg = MagicMock()
msg.usage_metadata = {"input_tokens": 10, "output_tokens": 20}
msg.response_metadata = {}
result = cal_tokens(msg)
assert result == (10, 20)
def test_case_002_token_usage_in_response_metadata():
msg = MagicMock()
msg.usage_metadata = None
msg.response_metadata = {"token_usage": {"input_tokens": 5, "output_tokens": 15}}
result = cal_tokens(msg)
assert result == (5, 15)
def test_get_beijing_time_default_format():
result = get_beijing_time()
assert datetime.datetime.strptime(result, "%Y-%m-%d %H:%M:%S")
def test_get_beijing_time_custom_format():
fmt = "%Y/%m/%d"
result = get_beijing_time(fmt)
assert datetime.datetime.strptime(result, fmt)
def test_convert_value2str_string():
assert convert_value2str("hello") == "hello"
def test_convert_value2str_decimal():
assert convert_value2str(Decimal("10.5")) == "10.5"
def test_convert_value2str_none():
assert convert_value2str(None) == ""
def test_convert_value2str_other_types():
assert convert_value2str(123) == "123"
assert convert_value2str(45.67) == "45.67"
assert convert_value2str(True) == "True"
def test_is_valid_uuid():
valid_uuid = str(uuid.uuid4())
assert is_valid_uuid(valid_uuid) is True
assert is_valid_uuid("invalid-uuid") is False
def test_is_number():
assert is_number("123") is True
assert is_number("123.45") is True
assert is_number(123) is True
assert is_number("not_a_number") is False
assert is_number(None) is False
def test_is_date():
assert is_date("2023-01-01") is True
assert is_date("Jan 1, 2023") is True
assert is_date("invalid-date") is False
def test_is_email():
assert is_email("test@example.com") is True
assert is_email("invalid-email") is False
def test_truncate_text():
text = "This is a long text for testing"
assert truncate_text(text, max_length=10) == "This..."
assert truncate_text(text, max_length=100) == text
assert truncate_text(text, max_length=3) == "Thi"
assert truncate_text(123, max_length=10) == 123
def test_truncate_text_by_byte():
text = "这是一个用于测试的长文本"
result = truncate_text_by_byte(text, max_length=20)
assert isinstance(result, str)
assert len(result.encode("utf-8")) <= 20
    # English text
english_text = "This is English text"
result = truncate_text_by_byte(english_text, max_length=15)
assert len(result.encode("utf-8")) <= 15
def sample_function(x):
return x * 2
async def async_sample_function(x):
    await asyncio.sleep(0.01)  # simulate async work
return x * 2
def test_parallel_exec():
data = [1, 2, 3, 4, 5]
result = parallel_exec(sample_function, data)
assert result == [2, 4, 6, 8, 10]
def test_parallel_exec_with_kwargs():
def func_with_kwargs(a, b):
return a + b
data = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
result = parallel_exec(func_with_kwargs, data)
assert result == [3, 7]
@pytest.mark.asyncio
async def test_async_parallel_exec():
data = [1, 2, 3, 4, 5]
result = await async_parallel_exec(async_sample_function, data)
assert result == [2, 4, 6, 8, 10]
@pytest.mark.asyncio
async def test_async_parallel_exec_with_sync_function():
data = [1, 2, 3, 4, 5]
result = await async_parallel_exec(sample_function, data)
assert result == [2, 4, 6, 8, 10]
async def async_add(a, b):
    await asyncio.sleep(0.1)  # simulate async work
return a + b
def test_exec_async_func():
result = exec_async_func(async_add, 1, 2)
assert result == 3
def test_exec_async_func_with_kwargs():
result = exec_async_func(async_add, a=1, b=2)
assert result == 3
def test_exec_async_func_with_mixed_args():
result = exec_async_func(async_add, 1, b=2)
assert result == 3
def test_calculate_metrics_normal_case():
    """Normal precision/recall/F1 calculation."""
    precision, recall, f1_score = calculate_metrics(10, 5, 3)
assert precision == 10 / 15 # 0.6667
assert recall == 10 / 13 # 0.7692
expected_f1 = 2 * precision * recall / (precision + recall)
assert f1_score == expected_f1
def test_calculate_metrics_zero_precision_and_recall():
    """Both precision and recall are 0."""
    precision, recall, f1_score = calculate_metrics(0, 0, 0)
assert precision == 0
assert recall == 0
assert f1_score == 0
def test_calculate_metrics_zero_precision():
    """Precision is 0."""
    precision, recall, f1_score = calculate_metrics(0, 5, 5)
assert precision == 0
assert recall == 0
assert f1_score == 0
def test_calculate_metrics_zero_recall():
    """Recall is 0."""
    precision, recall, f1_score = calculate_metrics(0, 0, 5)
assert precision == 0
assert recall == 0
assert f1_score == 0
def test_parse_real_cols_simple_select():
"""测试简单SELECT查询"""
sql = "SELECT a, b FROM table1"
result = parse_real_cols(sql)
expected = {"table1.a", "table1.b"}
assert result == expected
def test_parse_real_cols_with_join():
"""测试包含JOIN的查询"""
sql = "SELECT t1.a, t2.b FROM table1 t1 JOIN table2 t2 ON t1.id = t2.id"
result = parse_real_cols(sql)
expected = {"table1.a", "table1.id", "table2.b", "table2.id"}
assert result == expected
def test_parse_real_cols_with_cte():
"""测试包含CTE的查询"""
sql = """
WITH cte1 AS (SELECT a FROM table1)
SELECT cte1.a, table2.b
FROM cte1
JOIN table2 ON cte1.a = table2.a \
"""
result = parse_real_cols(sql)
expected = {"table1.a", "table2.a", "table2.b"}
assert result == expected
def test_parse_real_cols_case_insensitive():
"""测试大小写不敏感情况"""
sql = "SELECT A, b FROM Table1"
result = parse_real_cols(sql, case_insensitive=True)
expected = {"table1.a", "table1.b"}
assert result == expected
def test_parse_real_cols_case_sensitive():
"""测试大小写敏感情况"""
sql = "SELECT A, b FROM Table1"
result = parse_real_cols(sql, case_insensitive=False)
expected = {"Table1.A", "Table1.b"}
assert result == expected

2032
plugins/dative/uv.lock Normal file

File diff suppressed because it is too large