Initial commit
This commit is contained in:
277
main.py
Executable file
277
main.py
Executable file
@@ -0,0 +1,277 @@
|
||||
#!/usr/bin/env -S uv run --script
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "mcp",
|
||||
# "pyyaml",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from urllib.parse import urlparse
|
||||
from urllib.request import urlopen
|
||||
|
||||
import mcp.server.stdio
|
||||
import mcp.types as types
|
||||
import yaml
|
||||
from mcp.server import NotificationOptions, Server
|
||||
from mcp.server.models import InitializationOptions
|
||||
|
||||
|
||||
class OpenAPIServer:
|
||||
def __init__(self, docs_path: str):
|
||||
self.docs_source = docs_path
|
||||
self.docs_path = Path(docs_path) if not self._is_url(docs_path) else None
|
||||
self.server = Server("openapi-docs")
|
||||
self.spec = None
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger.info(f"Initializing OpenAPI server with docs source: {docs_path}")
|
||||
self.load_spec()
|
||||
self.setup_handlers()
|
||||
|
||||
def _is_url(self, path: str) -> bool:
|
||||
"""Check if the path is a URL"""
|
||||
parsed = urlparse(path)
|
||||
return parsed.scheme in ("http", "https")
|
||||
|
||||
def load_spec(self):
|
||||
"""Load OpenAPI spec from the specified file or URL"""
|
||||
self.logger.info(f"Loading OpenAPI spec from {self.docs_source}")
|
||||
|
||||
try:
|
||||
if self._is_url(self.docs_source):
|
||||
self._load_spec_from_url()
|
||||
else:
|
||||
self._load_spec_from_file()
|
||||
|
||||
if self.spec:
|
||||
self.logger.info(
|
||||
f"Successfully loaded OpenAPI spec from {self.docs_source}"
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"Failed to load OpenAPI spec from {self.docs_source}: {e}"
|
||||
)
|
||||
self.spec = None
|
||||
|
||||
def _load_spec_from_url(self):
|
||||
"""Load OpenAPI spec from a URL"""
|
||||
with urlopen(self.docs_source) as response:
|
||||
content = response.read().decode("utf-8")
|
||||
|
||||
# Try to determine format from URL or content
|
||||
if self.docs_source.lower().endswith((".yaml", ".yml")):
|
||||
self.spec = yaml.safe_load(content)
|
||||
elif self.docs_source.lower().endswith(".json"):
|
||||
self.spec = json.loads(content)
|
||||
else:
|
||||
# Try JSON first, then YAML
|
||||
try:
|
||||
self.spec = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
self.spec = yaml.safe_load(content)
|
||||
|
||||
def _load_spec_from_file(self):
|
||||
"""Load OpenAPI spec from a local file"""
|
||||
if not self.docs_path.exists():
|
||||
self.logger.error(f"OpenAPI spec file not found: {self.docs_path}")
|
||||
return
|
||||
|
||||
with open(self.docs_path) as f:
|
||||
if self.docs_path.suffix.lower() in [".yaml", ".yml"]:
|
||||
self.spec = yaml.safe_load(f)
|
||||
elif self.docs_path.suffix.lower() == ".json":
|
||||
self.spec = json.load(f)
|
||||
else:
|
||||
raise ValueError(f"Unsupported file format: {self.docs_path.suffix}")
|
||||
|
||||
def setup_handlers(self):
|
||||
@self.server.list_tools()
|
||||
async def handle_list_tools() -> list[types.Tool]:
|
||||
return [
|
||||
types.Tool(
|
||||
name="search_endpoints",
|
||||
description="Search API endpoints by keyword",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search term"}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
types.Tool(
|
||||
name="get_endpoint",
|
||||
description="Get details for a specific endpoint",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string"},
|
||||
"method": {"type": "string"},
|
||||
},
|
||||
"required": ["path", "method"],
|
||||
},
|
||||
),
|
||||
types.Tool(
|
||||
name="list_all_endpoints",
|
||||
description="List all available API endpoints",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
),
|
||||
]
|
||||
|
||||
@self.server.call_tool()
|
||||
async def handle_call_tool(
|
||||
name: str, arguments: dict | None
|
||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||
self.logger.info(f"Tool called: {name} with arguments: {arguments}")
|
||||
|
||||
if name == "search_endpoints":
|
||||
query = arguments.get("query", "") if arguments else ""
|
||||
self.logger.debug(f"Searching endpoints with query: {query}")
|
||||
results = self.search_endpoints(query)
|
||||
self.logger.info(f"Found {len(results)} matching endpoints")
|
||||
return [
|
||||
types.TextContent(type="text", text=json.dumps(results, indent=2))
|
||||
]
|
||||
|
||||
elif name == "get_endpoint":
|
||||
path = arguments.get("path", "") if arguments else ""
|
||||
method = arguments.get("method", "") if arguments else ""
|
||||
self.logger.debug(f"Getting endpoint details for: {method} {path}")
|
||||
result = self.get_endpoint_details(path, method)
|
||||
return [
|
||||
types.TextContent(type="text", text=json.dumps(result, indent=2))
|
||||
]
|
||||
|
||||
elif name == "list_all_endpoints":
|
||||
self.logger.debug("Listing all endpoints")
|
||||
results = self.list_endpoints()
|
||||
self.logger.info(f"Found {len(results)} total endpoints")
|
||||
return [
|
||||
types.TextContent(type="text", text=json.dumps(results, indent=2))
|
||||
]
|
||||
|
||||
else:
|
||||
self.logger.error(f"Unknown tool requested: {name}")
|
||||
raise ValueError(f"Unknown tool: {name}")
|
||||
|
||||
def search_endpoints(self, query: str) -> List[Dict]:
|
||||
"""Search endpoints by keyword in path, summary, or description"""
|
||||
results = []
|
||||
if not self.spec or "paths" not in self.spec:
|
||||
return results
|
||||
|
||||
query_lower = query.lower()
|
||||
for path, methods in self.spec["paths"].items():
|
||||
if isinstance(methods, dict):
|
||||
for method, details in methods.items():
|
||||
if method in ["get", "post", "put", "delete", "patch"]:
|
||||
if (
|
||||
query_lower in path.lower()
|
||||
or query_lower in details.get("summary", "").lower()
|
||||
or query_lower in details.get("description", "").lower()
|
||||
or any(
|
||||
query_lower in tag.lower()
|
||||
for tag in details.get("tags", [])
|
||||
)
|
||||
):
|
||||
results.append(
|
||||
{
|
||||
"path": path,
|
||||
"method": method.upper(),
|
||||
"summary": details.get("summary", ""),
|
||||
"tags": details.get("tags", []),
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
def get_endpoint_details(self, path: str, method: str) -> Dict:
|
||||
"""Get full details for a specific endpoint"""
|
||||
if not self.spec or "paths" not in self.spec:
|
||||
return {"error": "No spec loaded"}
|
||||
|
||||
path_data = self.spec["paths"].get(path, {})
|
||||
method_data = path_data.get(method.lower(), {})
|
||||
|
||||
if method_data:
|
||||
return {"path": path, "method": method.upper(), "details": method_data}
|
||||
return {"error": f"Endpoint {method.upper()} {path} not found"}
|
||||
|
||||
def list_endpoints(self) -> List[Dict]:
|
||||
"""List all available endpoints"""
|
||||
results = []
|
||||
if not self.spec or "paths" not in self.spec:
|
||||
return results
|
||||
|
||||
for path, methods in self.spec["paths"].items():
|
||||
if isinstance(methods, dict):
|
||||
for method, details in methods.items():
|
||||
if method in ["get", "post", "put", "delete", "patch"]:
|
||||
results.append(
|
||||
{
|
||||
"path": path,
|
||||
"method": method.upper(),
|
||||
"summary": details.get("summary", ""),
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
async def run(self):
|
||||
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
|
||||
await self.server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
InitializationOptions(
|
||||
server_name="openapi-docs",
|
||||
server_version="0.1.0",
|
||||
capabilities=self.server.get_capabilities(
|
||||
notification_options=NotificationOptions(),
|
||||
experimental_capabilities={},
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
],
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info("Starting OpenAPI MCP Server")
|
||||
|
||||
parser = argparse.ArgumentParser(description="OpenAPI MCP Server")
|
||||
parser.add_argument(
|
||||
"docs_path",
|
||||
help="Full path to OpenAPI spec file (YAML or JSON) or URL to remote spec",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
default="INFO",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||
help="Set the logging level (default: INFO)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.getLogger().setLevel(getattr(logging, args.log_level))
|
||||
logger.info(f"Log level set to {args.log_level}")
|
||||
|
||||
try:
|
||||
server = OpenAPIServer(args.docs_path)
|
||||
logger.info("Server initialized successfully")
|
||||
await server.run()
|
||||
except Exception as e:
|
||||
logger.error(f"Server failed to start: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user