I decided to prototype some code — I'd like your thoughts, @Rick.
It is AI-generated; I'm just using it to think through ideas. It gives us a scratchpad for discussing the implementation in a system that actually works. Don't feel you need to read it closely, beyond seeing that the concept works and roughly what it will look like.
The concept: every command you send is acknowledged with a task ID, which you can use to check the task's status. When the task completes successfully, the status includes a `new_lock`, which you must present to run the next command. This forces the two systems to stay in sync: you can only run another command after you have received a successful completion for the previous one. We can therefore return any resource changes in that `TaskStatusResponse` and be sure the client has at least received the updates.
import asyncio
import uuid
from datetime import datetime
from enum import Enum
from typing import Dict, Any, Optional, List
import json
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from pydantic import BaseModel
class TaskStatus(str, Enum):
    """Lifecycle state of a task.

    Because the enum mixes in ``str``, each member compares equal to its
    string value and serializes as that string.
    """

    # Accepted, not finished yet (the initial state of every Task).
    RUNNING = "running"
    # Finished successfully.
    COMPLETE = "complete"
    # Finished with an error; the task's ``error`` text explains why.
    FAILED = "failed"
class CommandRequest(BaseModel):
    """JSON body of a command submission."""

    # Keyword arguments forwarded to the command (e.g. ``duration`` for "sleep").
    # NOTE(review): relies on pydantic copying mutable defaults per instance;
    # ``Field(default_factory=dict)`` would make that explicit.
    kwargs: Dict[str, Any] = {}
    # Lock received from the most recent successful task, if any.
    last_lock: Optional[str] = None
class CommandResponse(BaseModel):
    """Acknowledgement returned when a command is accepted."""

    # ID of the task created for the command; use it to query the task status.
    cmd_running: str
class TaskStatusResponse(BaseModel):
    """Snapshot of a task's state, returned by the task-status endpoint."""

    status: TaskStatus
    # Populated only when ``status`` is COMPLETE; sent back as ``last_lock``
    # with the next command.
    new_lock: Optional[str] = None
    # Output of the command, when it produced one.
    result: Optional[Any] = None
    # Error text when the task failed.
    error: Optional[str] = None
class WebSocketMessage(BaseModel):
    """Task-finished notification broadcast to connected websocket clients.

    Carries the same fields as ``TaskStatusResponse`` plus the task ID, so a
    subscriber can tell which task the update is for.
    """

    task_id: str
    status: TaskStatus
    # Populated only when ``status`` is COMPLETE.
    new_lock: Optional[str] = None
    result: Optional[Any] = None
    error: Optional[str] = None
class Task:
    """In-memory record of one submitted command and its outcome."""

    def __init__(self, task_id: str, command: str, kwargs: Dict[str, Any]):
        # What was requested.
        self.task_id = task_id
        self.command = command
        self.kwargs = kwargs

        # Outcome; updated by the executor when the task finishes.
        self.status = TaskStatus.RUNNING
        self.result: Optional[Any] = None
        self.error: Optional[str] = None

        # Timestamps are naive local times from ``datetime.now()``.
        self.created_at = datetime.now()
        self.completed_at: Optional[datetime] = None

        # Random token handed to the client as ``new_lock`` once the task completes.
        self.lock = str(uuid.uuid4())
class TaskManager:
    """Run submitted commands in the background, gated by single-use locks.

    Every accepted command becomes a ``Task`` executed by ``_execute_task``.
    When a task completes successfully, its ``lock`` becomes
    ``last_completed_lock``. The next submission must present exactly that
    value; ``None`` is accepted only until the first task completes.

    A presented lock is redeemed once, so further submissions are rejected
    until the task it started has finished. After a failed task, the
    previously valid lock stays valid, so the client can retry with it.
    """

    def __init__(self):
        # NOTE(review): tasks are never pruned, so this grows for the life of
        # the process.
        self.tasks: Dict[str, Task] = {}
        # Serializes command execution and the completion broadcast.
        self.task_lock = asyncio.Lock()
        self.websocket_connections: List[WebSocket] = []
        # Lock issued by the most recent successful task; None until one exists.
        self.last_completed_lock: Optional[str] = None
        # True from the moment a submission redeems the current lock until
        # its task finishes; prevents one lock from starting two commands.
        self._lock_redeemed: bool = False
        # Strong references to in-flight asyncio tasks. The event loop keeps
        # only weak references, so an unreferenced task could be collected.
        self._background_tasks: "set[asyncio.Task]" = set()

    async def create_task(
        self, command: str, kwargs: Dict[str, Any], last_lock: Optional[str] = None
    ) -> str:
        """Validate ``last_lock`` and start ``command`` in the background.

        Args:
            command: Command name (the executor only implements ``"sleep"``).
            kwargs: Keyword arguments for the command.
            last_lock: Lock from the most recent successful task; ``None`` is
                valid only while no task has completed yet.

        Returns:
            The ID of the newly created task.

        Raises:
            HTTPException: 400 if ``last_lock`` does not match the latest
                issued lock, or if that lock was already redeemed by a task
                that is still running.
        """
        # There is no ``await`` between this check and marking the lock as
        # redeemed, so no other coroutine can interleave here.
        if self._lock_redeemed or last_lock != self.last_completed_lock:
            raise HTTPException(status_code=400, detail="Invalid lock provided")

        task_id = str(uuid.uuid4())
        task = Task(task_id, command, kwargs)
        self.tasks[task_id] = task

        runner = asyncio.create_task(self._execute_task(task))
        self._background_tasks.add(runner)
        runner.add_done_callback(self._background_tasks.discard)

        # Spend the lock only once the task is actually scheduled.
        self._lock_redeemed = True
        return task_id

    async def get_task_status(self, task_id: str) -> TaskStatusResponse:
        """Return a snapshot of ``task_id``'s state.

        ``new_lock`` is filled in only once the task is COMPLETE.

        Raises:
            HTTPException: 404 if no task with ``task_id`` exists.
        """
        task = self.tasks.get(task_id)
        if task is None:
            raise HTTPException(status_code=404, detail="Task not found")
        is_complete = task.status == TaskStatus.COMPLETE
        return TaskStatusResponse(
            status=task.status,
            new_lock=task.lock if is_complete else None,
            result=task.result,
            error=task.error,
        )

    async def _execute_task(self, task: Task):
        """Run ``task`` to completion, record its outcome, and broadcast it."""
        async with self.task_lock:
            try:
                if task.command == "sleep":
                    duration = task.kwargs.get("duration", 1.0)
                    await asyncio.sleep(duration)
                    task.result = f"Slept for {duration} seconds"
                    task.status = TaskStatus.COMPLETE
                else:
                    task.error = f"Unknown command: {task.command}"
                    task.status = TaskStatus.FAILED
            except Exception as e:
                task.error = str(e)
                task.status = TaskStatus.FAILED
            finally:
                task.completed_at = datetime.now()
                if task.status == TaskStatus.COMPLETE:
                    self.last_completed_lock = task.lock
                # Re-open submissions on every exit path. After a failure,
                # last_completed_lock is unchanged, so the old lock is valid
                # again for a retry.
                self._lock_redeemed = False
            await self._notify_websockets(task)

    async def _notify_websockets(self, task: Task):
        """Broadcast ``task``'s final state to every registered websocket.

        Sockets whose send raises are unregistered.
        """
        if not self.websocket_connections:
            return
        message = WebSocketMessage(
            task_id=task.task_id,
            status=task.status,
            new_lock=task.lock if task.status == TaskStatus.COMPLETE else None,
            result=task.result,
            error=task.error,
        )
        # Serialize once. A serialization error now surfaces instead of being
        # mistaken for every client disconnecting.
        payload = message.model_dump_json()
        disconnected = []
        # Iterate over a snapshot: remove_websocket() may run while a send is
        # awaited.
        for websocket in list(self.websocket_connections):
            try:
                await websocket.send_text(payload)
            except Exception:
                disconnected.append(websocket)
        for websocket in disconnected:
            # remove_websocket tolerates sockets that were already unregistered.
            self.remove_websocket(websocket)

    def add_websocket(self, websocket: WebSocket):
        """Register ``websocket`` to receive task-completion broadcasts."""
        self.websocket_connections.append(websocket)

    def remove_websocket(self, websocket: WebSocket):
        """Unregister ``websocket``; does nothing if it is not registered."""
        if websocket in self.websocket_connections:
            self.websocket_connections.remove(websocket)
app = FastAPI(
    title="PLR API",
    description="PyLabRobot API for command execution",
)

# Process-wide manager instance used by all of the route handlers.
task_manager = TaskManager()
@app.post("/plr_api/{cmd}", response_model=CommandResponse)
async def submit_command(cmd: str, request: CommandRequest):
    """Start command ``cmd`` with the request's kwargs and return its task ID.

    ``request.last_lock`` is forwarded to the task manager, which rejects the
    submission with HTTP 400 when the lock is not accepted.
    """
    new_task_id = await task_manager.create_task(
        command=cmd,
        kwargs=request.kwargs,
        last_lock=request.last_lock,
    )
    acknowledgement = CommandResponse(cmd_running=new_task_id)
    return acknowledgement
@app.get("/plr_api/tasks/{task_id}", response_model=TaskStatusResponse)
async def get_task_status(task_id: str):
    """Return the current status of ``task_id`` (HTTP 404 if it is unknown)."""
    status_snapshot = await task_manager.get_task_status(task_id)
    return status_snapshot
@app.websocket("/plr_api/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Subscribe a client to task-completion broadcasts until it disconnects.

    Messages sent by the client are read and discarded. The receive loop only
    serves to notice when the connection ends.
    """
    await websocket.accept()
    task_manager.add_websocket(websocket)
    try:
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        # Normal client disconnect; nothing further to do.
        pass
    finally:
        # Unregister on every exit path, not only a clean disconnect, so a
        # socket that errored out is not left in the broadcast list.
        task_manager.remove_websocket(websocket)
if __name__ == "__main__":
    # Deferred import: uvicorn is only needed when the file is run as a script.
    import uvicorn

    # NOTE(review): 0.0.0.0 listens on every network interface and the API
    # has no authentication; confirm this exposure is intended.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8080,
    )