fix: 引入sqlglot修复sql语句解析异常的问题

This commit is contained in:
insistence
2025-02-21 15:40:47 +08:00
parent 729dc23a16
commit ca641055e0
4 changed files with 47 additions and 16 deletions

View File

@@ -3,6 +3,7 @@ import os
import sys
from dotenv import load_dotenv
from functools import lru_cache
from pydantic import computed_field
from pydantic_settings import BaseSettings
from typing import Literal
@@ -51,6 +52,13 @@ class DataBaseSettings(BaseSettings):
db_pool_recycle: int = 3600
db_pool_timeout: int = 30
@computed_field
@property
def sqlglot_parse_dialect(self) -> str:
if self.db_type == 'postgresql':
return 'postgres'
return self.db_type
class RedisSettings(BaseSettings):
"""

View File

@@ -2,6 +2,7 @@ from datetime import datetime, time
from sqlalchemy import delete, func, select, text, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlglot.expressions import Expression
from typing import List
from config.env import DataBaseConfig
from module_generator.entity.do.gen_do import GenTable, GenTableColumn
@@ -75,15 +76,17 @@ class GenTableDao:
return gen_table_all
@classmethod
async def create_table_by_sql_dao(cls, db: AsyncSession, sql: str):
async def create_table_by_sql_dao(cls, db: AsyncSession, sql_statements: List[Expression]):
"""
根据sql语句创建表结构
:param db: orm对象
:param sql: sql语句
:param sql_statements: sql语句的ast列表
:return:
"""
await db.execute(text(sql))
for sql_statement in sql_statements:
sql = sql_statement.sql(dialect=DataBaseConfig.sqlglot_parse_dialect)
await db.execute(text(sql))
@classmethod
async def get_gen_table_list(cls, db: AsyncSession, query_object: GenTablePageQueryModel, is_page: bool = False):

View File

@@ -1,13 +1,14 @@
import io
import json
import os
import re
import zipfile
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlglot import parse as sqlglot_parse
from sqlglot.expressions import Add, Alter, Create, Delete, Drop, Expression, Insert, Table, TruncateTable, Update
from typing import List
from config.constant import GenConstant
from config.env import GenConfig
from config.env import DataBaseConfig, GenConfig
from exceptions.exception import ServiceException
from module_admin.entity.vo.common_vo import CrudResponseModel
from module_admin.entity.vo.user_vo import CurrentUserModel
@@ -197,10 +198,11 @@ class GenTableService:
:param current_user: 当前用户信息对象
:return: 创建表结构结果
"""
if cls.__is_valid_create_table(sql):
sql_statements = sqlglot_parse(sql, dialect=DataBaseConfig.sqlglot_parse_dialect)
if cls.__is_valid_create_table(sql_statements):
try:
table_names = re.findall(r'create\s+table\s+(\w+)', sql, re.IGNORECASE)
await GenTableDao.create_table_by_sql_dao(query_db, sql)
table_names = cls.__get_table_names(sql_statements)
await GenTableDao.create_table_by_sql_dao(query_db, sql_statements)
gen_table_list = await cls.get_gen_db_table_list_by_name_services(query_db, table_names)
await cls.import_gen_table_services(query_db, gen_table_list, current_user)
@@ -211,22 +213,39 @@ class GenTableService:
raise ServiceException(message='建表语句不合法')
@classmethod
def __is_valid_create_table(cls, sql: str):
def __is_valid_create_table(cls, sql_statements: List[Expression]):
"""
校验sql语句是否为合法的建表语句
:param sql: sql语句
:param sql_statements: sql语句的ast列表
:return: 校验结果
"""
create_table_pattern = r'^\s*CREATE\s+TABLE\s+'
if not re.search(create_table_pattern, sql, re.IGNORECASE):
validate_create = [isinstance(sql_statement, Create) for sql_statement in sql_statements]
validate_forbidden_keywords = [
isinstance(
sql_statement,
(Add, Alter, Delete, Drop, Insert, TruncateTable, Update),
)
for sql_statement in sql_statements
]
if not any(validate_create) or any(validate_forbidden_keywords):
return False
forbidden_keywords = ['INSERT', 'UPDATE', 'DELETE', 'DROP', 'ALTER', 'TRUNCATE']
for keyword in forbidden_keywords:
if re.search(rf'\b{keyword}\b', sql, re.IGNORECASE):
return False
return True
@classmethod
def __get_table_names(cls, sql_statements: List[Expression]):
"""
获取sql语句中所有的建表表名
:param sql_statements: sql语句的ast列表
:return: 建表表名列表
"""
table_names = []
for sql_statement in sql_statements:
if isinstance(sql_statement, Create):
table_names.append(sql_statement.find(Table).name)
return table_names
@classmethod
async def preview_code_services(cls, query_db: AsyncSession, table_id: int):
"""

View File

@@ -14,4 +14,5 @@ PyMySQL==1.1.1
redis==5.2.1
requests==2.32.3
SQLAlchemy[asyncio]==2.0.38
sqlglot[rs]==26.6.0
user-agents==2.2.0