151 lines
4.2 KiB
Python
151 lines
4.2 KiB
Python
"""工具模块测试"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
|
|
from minenasai.agent.tools.basic import (
|
|
list_directory_tool,
|
|
python_eval_tool,
|
|
read_file_tool,
|
|
)
|
|
|
|
|
|
class TestReadFileTool:
    """Tests for ``read_file_tool``."""

    @pytest.mark.asyncio
    async def test_read_existing_file(self, tmp_path):
        """Reading an existing file returns its content and its line count."""
        test_file = tmp_path / "test.txt"
        # Explicit encoding so the fixture does not depend on the platform locale.
        test_file.write_text("Hello, World!\nLine 2\nLine 3", encoding="utf-8")

        result = await read_file_tool(str(test_file))

        assert "error" not in result
        assert "Hello, World!" in result["content"]
        assert result["lines"] == 3

    @pytest.mark.asyncio
    async def test_read_nonexistent_file(self, tmp_path):
        """Reading a missing file yields an error mentioning "不存在" (does not exist)."""
        # Derive the missing path from tmp_path instead of hard-coding an absolute
        # host path such as "/nonexistent/file.txt": tmp_path is created fresh for
        # each test, so this file is guaranteed to be absent on any machine.
        missing_file = tmp_path / "missing.txt"

        result = await read_file_tool(str(missing_file))

        assert "error" in result
        assert "不存在" in result["error"]

    @pytest.mark.asyncio
    async def test_read_with_max_lines(self, tmp_path):
        """``max_lines`` truncates long files and marks the output with "截断" (truncated)."""
        max_lines = 10
        test_file = tmp_path / "long.txt"
        test_file.write_text(
            "\n".join(f"Line {i}" for i in range(100)), encoding="utf-8"
        )

        result = await read_file_tool(str(test_file), max_lines=max_lines)

        assert "error" not in result
        content = result["content"]
        assert "截断" in content
        # The marker alone would also be satisfied by a tool that returned all 100
        # lines plus a notice, so also check that at most ``max_lines`` of the
        # original "Line N" lines made it into the output.
        assert content.count("Line ") <= max_lines
|
|
|
|
|
|
class TestListDirectoryTool:
    """Tests for ``list_directory_tool``."""

    @pytest.mark.asyncio
    async def test_list_directory(self, tmp_path):
        """Listing a directory reports every entry, files and subdirectories alike."""
        # Populate the directory with two files and one subdirectory.
        (tmp_path / "file1.txt").touch()
        (tmp_path / "file2.py").touch()
        (tmp_path / "subdir").mkdir()

        result = await list_directory_tool(str(tmp_path))

        assert "error" not in result
        assert result["count"] == 3
        names = [item["name"] for item in result["items"]]
        # Check both files (not just one) so a listing that drops entries is caught.
        assert "file1.txt" in names
        assert "file2.py" in names

    @pytest.mark.asyncio
    async def test_list_with_pattern(self, tmp_path):
        """A glob ``pattern`` keeps only the matching entries."""
        (tmp_path / "test.py").touch()
        (tmp_path / "test.txt").touch()

        result = await list_directory_tool(str(tmp_path), pattern="*.py")

        assert "error" not in result
        assert result["count"] == 1
        assert result["items"][0]["name"] == "test.py"

    @pytest.mark.asyncio
    async def test_list_nonexistent_directory(self, tmp_path):
        """Listing a directory that does not exist returns an error."""
        # Derive the missing directory from tmp_path rather than hard-coding an
        # absolute host path like "/nonexistent/dir", so absence is guaranteed.
        missing_dir = tmp_path / "missing"

        result = await list_directory_tool(str(missing_dir))

        assert "error" in result
|
|
|
|
|
|
class TestPythonEvalTool:
    """Tests for ``python_eval_tool``: successful evaluation, sandbox blocks, and errors."""

    # -- expressions that should evaluate successfully ---------------------------

    @pytest.mark.asyncio
    async def test_simple_math(self):
        """Operator precedence is respected: ``1 + 2 * 3`` evaluates to 7."""
        outcome = await python_eval_tool("1 + 2 * 3")

        assert "error" not in outcome
        assert outcome["result"] == 7

    @pytest.mark.asyncio
    async def test_math_functions(self):
        """``math`` functions can be called from the evaluated expression."""
        outcome = await python_eval_tool("math.sqrt(16)")

        assert "error" not in outcome
        assert outcome["result"] == 4.0

    @pytest.mark.asyncio
    async def test_list_operations(self):
        """Builtins such as ``sum`` work on list literals."""
        outcome = await python_eval_tool("sum([1, 2, 3, 4, 5])")

        assert "error" not in outcome
        assert outcome["result"] == 15

    # -- expressions the sandbox must reject -------------------------------------

    @pytest.mark.asyncio
    async def test_blocked_import(self):
        """``__import__`` is refused with an error containing "不允许" (not allowed)."""
        outcome = await python_eval_tool("__import__('os')")

        assert "error" in outcome
        assert "不允许" in outcome["error"]

    @pytest.mark.asyncio
    async def test_blocked_exec(self):
        """``exec`` is refused and reported as an error."""
        outcome = await python_eval_tool("exec('print(1)')")

        assert "error" in outcome

    @pytest.mark.asyncio
    async def test_blocked_open(self):
        """``open`` is refused and reported as an error."""
        outcome = await python_eval_tool("open('/etc/passwd')")

        assert "error" in outcome

    # -- malformed or failing expressions ----------------------------------------

    @pytest.mark.asyncio
    async def test_syntax_error(self):
        """An incomplete expression yields an error containing "语法错误" (syntax error)."""
        outcome = await python_eval_tool("1 +")

        assert "error" in outcome
        assert "语法错误" in outcome["error"]

    @pytest.mark.asyncio
    async def test_runtime_error(self):
        """A runtime exception's message is surfaced in the error field."""
        outcome = await python_eval_tool("1 / 0")

        assert "error" in outcome
        assert "division by zero" in outcome["error"]
|