Coverage for /usr/local/lib/python3.14/site-packages/twinpad_backend/db.py: 100%
106 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-29 08:22 +0000
import asyncio
import logging
import os
import time
from urllib.parse import quote_plus

from pymongo import MongoClient, AsyncMongoClient, errors
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.collection import Collection
from pymongo.database import Database
12MONGO_USERNAME = os.environ.get("MONGO_USERNAME", "mongo")
13MONGO_PASSWORD = os.environ.get("MONGO_PASSWORD", "change_password")
14MONGO_HOST = os.environ.get("MONGO_HOST", "localhost")
16MAX_RETRIES = 5
17RETRY_DELAY = 10
19MONGO_CONN_STRING = f"mongodb://{MONGO_USERNAME}:{MONGO_PASSWORD}@{MONGO_HOST}/?retryWrites=true&w=majority"
21logger = logging.getLogger("uvicorn.error")
23logger.info("Connecting to mongo database: %s @ %s", MONGO_USERNAME, MONGO_HOST)
def connect_to_mongo():
    """Open a synchronous MongoDB client, retrying on connection failures.

    Up to ``MAX_RETRIES`` attempts are made. Each attempt builds a
    ``MongoClient`` from ``MONGO_CONN_STRING`` and sends a ``ping`` to the
    ``admin`` database to check that the server is reachable. Between two
    failed attempts the function sleeps ``RETRY_DELAY`` seconds.

    :return: A connected ``MongoClient``, or ``None`` if every attempt failed.
    """
    for attempt in range(1, MAX_RETRIES + 1):
        client = None
        try:
            client = MongoClient(MONGO_CONN_STRING)
            client.admin.command("ping")
        except errors.ConnectionFailure as e:  # pragma: no cover
            # Close the failed client; otherwise its background monitoring
            # and connection pool stay alive for every retry.
            if client is not None:
                client.close()
            logger.warning("Attempt %d/%d failed: %s", attempt, MAX_RETRIES, e)
            if attempt < MAX_RETRIES:
                logger.info("New attempt in %d seconds...", RETRY_DELAY)
                time.sleep(RETRY_DELAY)
            continue
        logger.info("Database connected")
        return client
    logger.error("Cannot connect to database after %d retries", MAX_RETRIES)  # pragma: no cover
    return None  # pragma: no cover
client = connect_to_mongo()
if client is None:  # pragma: no cover
    # Fail fast with an explicit message instead of the opaque
    # "'NoneType' object has no attribute 'signals'" the lines below would raise.
    raise RuntimeError(f"Cannot connect to MongoDB at {MONGO_HOST} after {MAX_RETRIES} attempts")
async_client = AsyncMongoClient(MONGO_CONN_STRING)

# Synchronous and asynchronous handles on the databases used by the backend.
signals_database = client.signals
signals_async_database = async_client.signals
systems_database = client.systems
systems_async_database = async_client.systems
devices_states_database = client.devices_states
devices_states_async_database = async_client.devices_states

# Per-database memo of collection handles, keyed by database then collection
# name; filled by get_collection / get_collections_batch (sync) and
# get_async_collection (async).
_collection_cache: dict[Database, dict[str, Collection]] = {}
_async_collection_cache: dict[AsyncDatabase, dict[str, AsyncCollection]] = {}
# AsyncMongoClient instances created for event loops other than the one the
# shared ``async_client`` is bound to. Keyed by loop so that each loop gets a
# single client that is reused, instead of a fresh (never closed) client per call.
_loop_async_clients: dict[asyncio.AbstractEventLoop, AsyncMongoClient] = {}


def get_async_database(name: str):
    """Return the async database ``name`` usable from the current event loop.

    The shared ``async_client`` is used when its loop is the current one;
    otherwise a client dedicated to the current loop is created on first use
    and reused on later calls from that loop.

    :param name: Database name.
    :return: The database attribute of the selected client, or ``None`` when
        ``getattr`` cannot resolve ``name``.
    """
    # NOTE(review): ``_loop`` is a private PyMongo attribute; confirm it is
    # populated as expected with the installed PyMongo version.
    loop = asyncio.get_event_loop()
    if async_client._loop == loop:
        return getattr(async_client, name, None)

    # Drop clients whose loop has been closed: they can no longer be used.
    for stale_loop in [known for known in _loop_async_clients if known.is_closed()]:
        del _loop_async_clients[stale_loop]

    loop_client = _loop_async_clients.get(loop)
    if loop_client is None:
        loop_client = _loop_async_clients[loop] = AsyncMongoClient(MONGO_CONN_STRING)
    return getattr(loop_client, name, None)
def get_collection(
    database: Database, collection_name: str, create: bool = False, time_series: bool = False
) -> Collection | None:
    """Return a collection of ``database``, optionally creating it.

    Handles are memoised in ``_collection_cache``, so the server is queried only
    the first time a given collection is requested.

    :param database: MongoDB database in which the collection is located.
    :param collection_name: Name of the wanted collection.
    :param create: Whether to create the collection if it is not found, defaults to False.
    :param time_series: Whether a created collection should be a time series on the
        ``timestamp`` field, defaults to False.
    :return: The collection, or None if it does not exist and ``create`` is False.
    """
    database_cache = _collection_cache.setdefault(database, {})
    if collection_name in database_cache:
        return database_cache[collection_name]

    # Ask the server about this single name instead of listing every collection.
    if database.list_collection_names(filter={"name": collection_name}):
        collection = database[collection_name]
    elif create:
        options = {"timeseries": {"timeField": "timestamp"}} if time_series else {}
        try:
            collection = database.create_collection(collection_name, **options)
        except errors.CollectionInvalid:
            # Created by someone else between our existence check and this call.
            collection = database[collection_name]
        except errors.OperationFailure as exc:
            if exc.code != 48:  # 48 is MongoDB's NamespaceExists error code
                raise
            collection = database[collection_name]
    else:
        return None

    database_cache[collection_name] = collection
    return collection
def get_collections_batch(
    database: Database, collection_names: list[str], create: bool = False, time_series: bool = False
) -> list[Collection | None]:
    """Fetches a batch of collections from an existing database.

    Cached handles (``_collection_cache``) are returned directly; for the others
    the collection names of ``database`` are listed at most once per call.

    :param database: MongoDB database in which the collections are located.
    :type database: Database
    :param collection_names: Names of the wanted collections.
    :type collection_names: list[str]
    :param create: Whether or not to create the collection if not found, defaults to False.
    :type create: bool, optional
    :param time_series: Whether or not the created collection should be a time series
        (on the ``timestamp`` field), defaults to False.
    :type time_series: bool, optional
    :return: A list of the requested collections, in the order of ``collection_names``.
        If :py:attr:`create` was false and the collection was not found, element is None.
    :rtype: list[Collection | None]
    """
    database_cache = _collection_cache.setdefault(database, {})
    create_options = {"timeseries": {"timeField": "timestamp"}} if time_series else {}
    # Names present on the server, fetched lazily and kept as a set so each
    # membership test is O(1) instead of scanning a list.
    existing_names: set[str] | None = None
    collections: list[Collection | None] = []

    for collection_name in collection_names:
        # Compare with None: pymongo Collection objects refuse truth-value testing.
        cached = database_cache.get(collection_name)
        if cached is not None:
            collections.append(cached)
            continue

        if existing_names is None:
            existing_names = set(database.list_collection_names())

        if collection_name in existing_names:
            collection = database[collection_name]
        elif create:
            try:
                collection = database.create_collection(collection_name, **create_options)
            except errors.CollectionInvalid:
                # Created by someone else after we listed the collections.
                collection = database[collection_name]
            except errors.OperationFailure as exc:
                if exc.code != 48:  # 48 is MongoDB's NamespaceExists error code
                    raise
                collection = database[collection_name]
        else:
            collections.append(None)
            continue

        database_cache[collection_name] = collection
        collections.append(collection)

    return collections
async def get_async_collection(
    database: AsyncDatabase, collection_name: str, create: bool = False, time_series: bool = False
) -> AsyncCollection | None:
    """Return a collection of the async ``database``, optionally creating it.

    Handles are memoised in ``_async_collection_cache``. Because the existence
    check and the creation are separate awaited calls, another coroutine (or
    worker) may create the collection in between; that case is handled by
    falling back to the existing collection.

    :param database: Async MongoDB database in which the collection is located.
    :param collection_name: Name of the wanted collection.
    :param create: Whether to create the collection if it is not found, defaults to False.
    :param time_series: Whether a created collection should be a time series on the
        ``timestamp`` field, defaults to False.
    :return: The collection, or None if it does not exist and ``create`` is False.
    """
    database_cache = _async_collection_cache.setdefault(database, {})
    if collection_name in database_cache:
        return database_cache[collection_name]

    # Ask the server about this single name instead of listing every collection.
    if await database.list_collection_names(filter={"name": collection_name}):
        collection = database[collection_name]
    elif create:
        options = {"timeseries": {"timeField": "timestamp"}} if time_series else {}
        try:
            collection = await database.create_collection(collection_name, **options)
        except errors.CollectionInvalid:
            collection = database[collection_name]
        except errors.OperationFailure as exc:
            if exc.code != 48:  # 48 is MongoDB's NamespaceExists error code
                raise
            collection = database[collection_name]
    else:
        return None

    database_cache[collection_name] = collection
    return collection
def get_signal_collection(signal_id: str, create=False):
    """Return the time-series collection storing one signal.

    The collection lives in the signals database and is named after the signal.

    :param signal_id: ID of the signal, used as the collection name.
    :param create: Create the collection (as a time series) when it is missing.
    :return: The collection, or None when it is missing and ``create`` is False.
    """
    return get_collection(
        database=signals_database,
        collection_name=signal_id,
        create=create,
        time_series=True,
    )
def get_signal_collections_batch(signal_ids: list[str], create=False):
    """Look up several signal collections of the signals database in one go.

    Each signal is stored in a collection named after its ID; missing ones are
    created as time series when ``create`` is set.

    :param signal_ids: IDs of the wanted signals, one collection per ID.
    :type signal_ids: list[str]
    :param create: Create missing collections as time series, defaults to False.
    :type create: bool, optional
    :return: One entry per ID, in input order; None for a missing collection when
        ``create`` is False.
    :rtype: list[Collection | None]
    """
    return get_collections_batch(
        database=signals_database,
        collection_names=signal_ids,
        create=create,
        time_series=True,
    )
def signal_datasize():
    """Return the ``dataSize`` value reported by ``dbstats`` for the signals database."""
    stats = signals_database.command("dbstats")
    return stats["dataSize"]
# Users collection of the systems database (created if missing), followed by a
# migration executed at import time: in every user document whose ``is_active``
# field is present and not null, that field is renamed to ``is_blocked`` with
# its value carried over unchanged. Once no such document remains, the update
# matches nothing.
# NOTE(review): the rename does not invert the stored value, so a user stored
# with ``is_active: True`` ends up with ``is_blocked: True`` — confirm this is
# intended (e.g. values already carried the "blocked" meaning).
users_collection = get_collection(systems_database, "users", create=True)
users_collection.update_many(
    filter={"is_active": {"$ne": None}},
    update={"$rename": {"is_active": "is_blocked"}},
)