diff --git a/srv/energyDB.py b/srv/energyDB.py index 3677e6e..ddb7a85 100644 --- a/srv/energyDB.py +++ b/srv/energyDB.py @@ -3,6 +3,7 @@ from fastapi import FastAPI, Header, HTTPException from pydantic import BaseModel import datetime import databases +import os import sqlalchemy from sqlite3 import OperationalError @@ -11,9 +12,6 @@ API_VERSION_MINOR = 0 REST_API_ROOT = f"/energy/v{API_VERSION_MAJOR}/" -DATABASE_URL = "sqlite:///./energyDB.sqlite" -# DATABASE_URL = "sqlite://" -db = databases.Database(DATABASE_URL) metadata = sqlalchemy.MetaData() energy = sqlalchemy.Table( "energy", @@ -23,9 +21,14 @@ energy = sqlalchemy.Table( sqlalchemy.Column("value", sqlalchemy.Float), sqlalchemy.UniqueConstraint("timestamp"), ) + +# DATABASE_URL = "sqlite:///./energyDB.sqlite" +DATABASE_URL = os.getenv("DATABASE_URL", default="sqlite://") +print(f"DB URL: {DATABASE_URL}") +db = databases.Database(DATABASE_URL) engine = sqlalchemy.create_engine( - DATABASE_URL, connect_args={"check_same_thread": False} -) + DATABASE_URL, + connect_args={"check_same_thread": False}) metadata.create_all(engine) class EnergyValue(BaseModel): @@ -85,8 +88,11 @@ async def getBulkEnergyData(bulkDataRequest: BulkDataRequest): .where(energy.c.channel_id == ch) \ .where(energy.c.timestamp >= bulkDataRequest.fromTime) \ .where(energy.c.timestamp <= bulkDataRequest.tillTime) - data = await db.fetch_all(query) - bulkData.append({"channel_id": ch, "data": data}) + try: + data = await db.fetch_all(query) + bulkData.append({"channel_id": ch, "data": data}) + except OperationalError as e: + raise HTTPException(status_code=500, detail="Database error") return { "bulk": bulkData, diff --git a/start_server.sh b/start_server.sh index 689dae1..7f4e37e 100755 --- a/start_server.sh +++ b/start_server.sh @@ -15,6 +15,6 @@ fi ARG_HTTP_PORT=${HTTP_PORT:-8000} ARG_IP_ADDRESS=${IP_ADDRESS:-127.0.0.1} -export DB_URL=${DATABASE_URL:-sqlite://} +export DATABASE_URL /usr/bin/env uvicorn --port $ARG_HTTP_PORT --host $ARG_IP_ADDRESS ${UVICORN_ARGS} srv:energyDB diff --git a/test/test_EnergyDB.py b/test/test_EnergyDB.py index 39fd002..2d20b30 100644 --- a/test/test_EnergyDB.py +++ b/test/test_EnergyDB.py @@ -3,8 +3,14 @@ import pytest from datetime import datetime import json +import os import urllib.parse +#TODO Use in-memory DB to test the case that there is no table +#TODO Add helper function to fill the in-memory DB before test + +os.environ["DATABASE_URL"] = "sqlite:///./energyDB.sqlite" + from srv import energyDB class Test_energyDb: @@ -66,6 +72,7 @@ class Test_energyDb: assert response.status_code == 200 def test_bulkData_get(self): + print(f"DB_URL: {os.getenv('DATABASE_URL')}") # response = self.client.put("/energy/bulkData", json=self.bulkTestData); fromTimestamp = datetime.fromisoformat("2020-12-11T12:30:00") tillTimestamp = datetime.fromisoformat("2020-12-11T12:30:59")