Coverage for /usr/local/lib/python3.11/site-packages/twinpad_backend/db.py: 85%

100 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-05 08:18 +0000

1import os 

2import logging 

3import time 

4 

5from pymongo import MongoClient, AsyncMongoClient, errors 

6from pymongo.asynchronous.collection import AsyncCollection 

7from pymongo.asynchronous.database import AsyncDatabase 

8from pymongo.collection import Collection 

9from pymongo.database import Database 

10 

MONGO_USERNAME = os.environ.get("MONGO_USERNAME", "mongo")
MONGO_PASSWORD = os.environ.get("MONGO_PASSWORD", "change_password")
MONGO_HOST = os.environ.get("MONGO_HOST", "localhost")

# Number of connection attempts made by connect_to_mongo() and the pause (seconds) between them.
MAX_RETRIES = 5
RETRY_DELAY = 10


def build_mongo_conn_string(username: str, password: str, host: str) -> str:
    """Build the ``mongodb://`` connection URI used by the sync and async clients.

    The username and password are percent-escaped so that characters with a
    meaning in URIs (``@``, ``:``, ``/``, ``%`` ...) cannot break the authority
    part of the URI. ``host`` is inserted verbatim so a ``host:port`` (or a
    comma-separated host list) keeps working.

    :param username: MongoDB user name (raw, unescaped).
    :param password: MongoDB password (raw, unescaped).
    :param host: Host part of the URI, e.g. ``localhost`` or ``db:27017``.
    :return: The full connection URI with ``retryWrites=true&w=majority`` options.
    """
    from urllib.parse import quote_plus

    return (
        f"mongodb://{quote_plus(username)}:{quote_plus(password)}@{host}"
        "/?retryWrites=true&w=majority"
    )


MONGO_CONN_STRING = build_mongo_conn_string(MONGO_USERNAME, MONGO_PASSWORD, MONGO_HOST)

logger = logging.getLogger("uvicorn.error")

# Only the user name and host are logged; the password never reaches the logs.
logger.info("Connecting to mongo database: %s @ %s", MONGO_USERNAME, MONGO_HOST)

23 

24 

def connect_to_mongo(*, max_retries: int | None = None, retry_delay: int | None = None):
    """Open a synchronous MongoDB client, retrying until the server answers a ping.

    Each attempt builds a :class:`MongoClient` from :data:`MONGO_CONN_STRING` and
    sends an ``admin`` ``ping``. A :class:`pymongo.errors.ConnectionFailure` counts
    as a failed attempt: the half-open client is closed, and the function sleeps
    before the next try (no sleep after the last one). Other exceptions (e.g. an
    authentication ``OperationFailure``) are not retried and propagate.

    :param max_retries: Number of attempts, defaults to :data:`MAX_RETRIES`.
    :type max_retries: int, optional
    :param retry_delay: Seconds to wait between attempts, defaults to :data:`RETRY_DELAY`.
    :type retry_delay: int, optional
    :return: A client whose ping succeeded, or None if every attempt failed.
    :rtype: MongoClient | None
    """
    # Resolve defaults at call time so changes to the module constants are honoured.
    max_retries = MAX_RETRIES if max_retries is None else max_retries
    retry_delay = RETRY_DELAY if retry_delay is None else retry_delay

    for attempt in range(1, max_retries + 1):
        client = None
        try:
            client = MongoClient(MONGO_CONN_STRING)
            client.admin.command("ping")
        except errors.ConnectionFailure as e:  # pragma: no cover
            # A MongoClient owns background monitor threads and a socket pool;
            # release them instead of leaking one client per failed attempt.
            if client is not None:
                client.close()
            logger.warning("Attempt %d/%d failed: %s", attempt, max_retries, e)
            if attempt < max_retries:
                logger.info("New attempt in %d seconds...", retry_delay)
                time.sleep(retry_delay)
        else:
            logger.info("Database connected")
            return client
    logger.error("Cannot connect to database after %d retries", max_retries)  # pragma: no cover
    return None  # pragma: no cover

41 

42 

client = connect_to_mongo()
if client is None:  # pragma: no cover
    # Fail at import with an explicit error rather than an opaque
    # AttributeError on ``None.signals`` a few lines below.
    raise errors.ConnectionFailure(
        f"Cannot connect to MongoDB at {MONGO_HOST} after {MAX_RETRIES} retries"
    )
# The async client is constructed directly; no ping is sent for it here.
async_client = AsyncMongoClient(MONGO_CONN_STRING)

signals_database = client.signals
signals_async_database = async_client.signals
systems_database = client.systems
systems_async_database = async_client.systems
devices_states_database = client.devices_states
devices_states_async_database = async_client.devices_states

# Per-database memo of collection handles (database -> collection name -> handle),
# filled by the get_*collection* helpers so repeat lookups skip list_collection_names().
_collection_cache: dict[Database, dict[str, Collection]] = {}
_async_collection_cache: dict[AsyncDatabase, dict[str, AsyncCollection]] = {}

56 

57 

def get_collection(
    database: Database, collection_name: str, create: bool = False, time_series: bool = False
) -> Collection | None:
    """Fetch a single collection from an existing database, optionally creating it.

    Handles are memoised in ``_collection_cache``, so the server is only queried the
    first time a given collection is requested. A collection missing from the cache
    is not remembered when it is not found.

    :param database: MongoDB database in which the collection is located.
    :type database: Database
    :param collection_name: Name of the wanted collection.
    :type collection_name: str
    :param create: Whether or not to create the collection if not found, defaults to False.
    :type create: bool, optional
    :param time_series: Whether or not a created collection should be a time series
        with ``timestamp`` as its time field, defaults to False.
    :type time_series: bool, optional
    :return: The requested collection, or None if it was not found and ``create`` was false.
    :rtype: Collection | None
    """
    database_cache = _collection_cache.setdefault(database, {})
    if collection_name in database_cache:
        return database_cache[collection_name]

    if collection_name in database.list_collection_names():
        collection = database[collection_name]
    elif create:
        options = {"timeseries": {"timeField": "timestamp"}} if time_series else {}
        try:
            collection = database.create_collection(collection_name, **options)
        except errors.CollectionInvalid:
            # Another client created the collection after our listing; use it as-is.
            collection = database[collection_name]
    else:
        return None

    database_cache[collection_name] = collection
    return collection

81 

82 

def get_collections_batch(
    database: Database, collection_names: list[str], create: bool = False, time_series: bool = False
) -> list[Collection | None]:
    """Fetches a batch of collections from an existing database.

    Cached handles (``_collection_cache``) are returned directly. The server's
    collection list is fetched at most once per call, and only if some requested
    name is not cached yet.

    :param database: MongoDB database in which the collections are located.
    :type database: Database
    :param collection_names: Names of the wanted collections.
    :type collection_names: list[str]
    :param create: Whether or not to create the collection if not found, defaults to False.
    :type create: bool, optional
    :param time_series: Whether or not the created collection should be a time series
        with ``timestamp`` as its time field, defaults to False.
    :type time_series: bool, optional
    :return: A list of the requested collections, in the order of ``collection_names``.
        If ``create`` was false and a collection was not found, its element is None.
    :rtype: list[Collection | None]
    """
    collections: list[Collection | None] = []
    database_cache = _collection_cache.setdefault(database, {})
    # Fetched lazily and turned into a set so each membership test is O(1).
    existing_names: set[str] | None = None
    create_options = {"timeseries": {"timeField": "timestamp"}} if time_series else {}

    for collection_name in collection_names:
        if collection_name in database_cache:
            collections.append(database_cache[collection_name])
            continue

        if existing_names is None:
            existing_names = set(database.list_collection_names())

        if collection_name in existing_names:
            collection = database[collection_name]
        elif create:
            try:
                collection = database.create_collection(collection_name, **create_options)
            except errors.CollectionInvalid:
                # Another client created the collection after our listing; use it as-is.
                collection = database[collection_name]
        else:
            collections.append(None)
            continue

        database_cache[collection_name] = collection
        collections.append(collection)

    return collections

133 

134 

async def get_async_collection(
    database: AsyncDatabase, collection_name: str, create: bool = False, time_series: bool = False
) -> AsyncCollection | None:
    """Asynchronously fetch a single collection, optionally creating it.

    This is the async counterpart of ``get_collection``. It memoises handles in
    ``_async_collection_cache``, separate from the synchronous cache.

    :param database: Async MongoDB database in which the collection is located.
    :type database: AsyncDatabase
    :param collection_name: Name of the wanted collection.
    :type collection_name: str
    :param create: Whether or not to create the collection if not found, defaults to False.
    :type create: bool, optional
    :param time_series: Whether or not a created collection should be a time series
        with ``timestamp`` as its time field, defaults to False.
    :type time_series: bool, optional
    :return: The requested collection, or None if it was not found and ``create`` was false.
    :rtype: AsyncCollection | None
    """
    database_cache = _async_collection_cache.setdefault(database, {})
    if collection_name in database_cache:
        return database_cache[collection_name]

    if collection_name in (await database.list_collection_names()):
        collection = database[collection_name]
    elif create:
        options = {"timeseries": {"timeField": "timestamp"}} if time_series else {}
        try:
            collection = await database.create_collection(collection_name, **options)
        except errors.CollectionInvalid:
            # Another client created the collection after our listing; use it as-is.
            collection = database[collection_name]
    else:
        return None

    database_cache[collection_name] = collection
    return collection

158 

159 

def get_signal_collection(signal_id: str, create=False):
    """Fetch the collection named ``signal_id`` from :py:const:`signals_database`.

    :param signal_id: Signal ID, used as the collection name.
    :type signal_id: str
    :param create: Whether or not to create the collection (as a time series) if it
        does not exist, defaults to False.
    :type create: bool, optional
    :return: The collection, or None if it was not found and ``create`` was false.
    :rtype: Collection | None
    """
    return get_collection(
        database=signals_database,
        collection_name=signal_id,
        create=create,
        time_series=True,
    )

162 

163 

def get_signal_collections_batch(signal_ids: list[str], create=False):
    """Fetch several signal collections at once from :py:const:`signals_database`.

    Each signal ID is used as a collection name. Missing collections are created as
    time series when ``create`` is true.

    :param signal_ids: Signal IDs of the wanted signals.
    :type signal_ids: list[str]
    :param create: Whether or not to create a collection (as a time series) when it
        does not exist, defaults to False.
    :type create: bool, optional
    :return: The collections, in the order of ``signal_ids``. An element is None when
        its collection was not found and ``create`` was false.
    :rtype: list[Collection | None]
    """
    return get_collections_batch(
        database=signals_database,
        collection_names=signal_ids,
        create=create,
        time_series=True,
    )

175 

176 

def get_signals_ids_from_collection_names():
    """List the collection names of :py:const:`signals_database`, treated as signal IDs.

    This is a low-level helper. It queries the server directly (it does not use the
    collection cache) and skips names in the reserved ``system.`` namespace.

    :return: The collection names, sorted in ascending order.
    :rtype: list[str]
    """
    all_names = signals_database.list_collection_names()
    return sorted(name for name in all_names if not name.startswith("system."))

182 

183 

184# def get_signals(): 

185# return [signal for signal in get_collection(systems_database, "signals", create=True).find({})] 

186 

187# def get_signals_ids(): 

188# return [signal['signal_id'] for signal in get_signals()] 

189 

190 

def signal_datasize():
    """Return the ``dataSize`` field of the ``dbstats`` command on :py:const:`signals_database`."""
    stats = signals_database.command("dbstats")
    return stats["dataSize"]