Skip to content

Commit

Permalink
SQLalchemy upgrade - step one (#2979)
Browse files Browse the repository at this point in the history
* Remove legacy AclBaseQuery to align with SQLALchemy 2.x

* Remove use of query.get() method to align with SQLALchemy 2.x

* Add explcit db session.add

* Remove db model object constructors

* Add relationships to User model

* Add required BaseQuery

* Disable linting for unmaintained contrib

* Renamed BaseQuery to Query

* fix test

* formatting

* Add future flag to create_engine()

---------

Co-authored-by: Janosch <[email protected]>
  • Loading branch information
berggren and jkppr authored Dec 20, 2023
1 parent ee37d03 commit 07b6b2e
Show file tree
Hide file tree
Showing 35 changed files with 242 additions and 902 deletions.
6 changes: 4 additions & 2 deletions contrib/gcs_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Google Cloud Storage importer."""
# Unmaintained contrib. Skip linting this file.
# pylint: skip-file

import argparse
import time
Expand Down Expand Up @@ -80,12 +82,12 @@ def setup_sketch(timeline_name, index_name, username, sketch_id=None):
(tuple) sketch ID and timeline ID as integers
"""
with app.app_context():
user = User.get_or_create(username=username)
user = User.get_or_create(username=username, name=username)
sketch = None

if sketch_id:
try:
sketch = Sketch.query.get_with_acl(sketch_id, user=user)
sketch = Sketch.get_with_acl(sketch_id, user=user)
logger.info(
"Using existing sketch: {} ({})".format(sketch.name, sketch.id)
)
Expand Down
36 changes: 18 additions & 18 deletions timesketch/api/v1/resources/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,15 @@ def get(self, sketch_id, aggregation_id): # pylint: disable=unused-argument
Returns:
JSON with aggregation results
"""
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")
if not sketch.has_permission(current_user, "read"):
abort(
HTTP_STATUS_CODE_FORBIDDEN,
"User does not have read access controls on sketch.",
)
aggregation = Aggregation.query.get(aggregation_id)
aggregation = Aggregation.get_by_id(aggregation_id)

# Check that this aggregation belongs to the sketch
if aggregation.sketch_id != sketch.id:
Expand Down Expand Up @@ -111,7 +111,7 @@ def post(self, sketch_id, aggregation_id):
if not form:
abort(HTTP_STATUS_CODE_BAD_REQUEST, "Unable to validate form data.")

sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")
if not sketch.has_permission(current_user, "write"):
Expand All @@ -120,7 +120,7 @@ def post(self, sketch_id, aggregation_id):
"User does not have write access controls on sketch.",
)

aggregation = Aggregation.query.get(aggregation_id)
aggregation = Aggregation.get_by_id(aggregation_id)
if not aggregation:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No aggregation found with this ID.")

Expand Down Expand Up @@ -169,11 +169,11 @@ def delete(self, sketch_id, aggregation_id):
group_id: Integer primary key for an aggregation group database
model.
"""
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")

aggregation = Aggregation.query.get(aggregation_id)
aggregation = Aggregation.get_by_id(aggregation_id)
if not aggregation:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No aggregation found with this ID.")

Expand Down Expand Up @@ -274,8 +274,8 @@ def get(self, sketch_id, group_id):
sketch_id: Integer primary key for a sketch database model.
group_id: Integer primary key for an aggregation group database
"""
sketch = Sketch.query.get_with_acl(sketch_id)
group = AggregationGroup.query.get(group_id)
sketch = Sketch.get_with_acl(sketch_id)
group = AggregationGroup.get_by_id(group_id)

if not group:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No Group found with this ID.")
Expand Down Expand Up @@ -321,8 +321,8 @@ def post(self, sketch_id, group_id):
group_id: Integer primary key for an aggregation group database
model.
"""
sketch = Sketch.query.get_with_acl(sketch_id)
group = AggregationGroup.query.get(group_id)
sketch = Sketch.get_with_acl(sketch_id)
group = AggregationGroup.get_by_id(group_id)
if not group:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No Group found with this ID.")

Expand Down Expand Up @@ -362,7 +362,7 @@ def post(self, sketch_id, group_id):
aggregations = []

for agg_id in agg_ids:
aggregation = Aggregation.query.get(agg_id)
aggregation = Aggregation.get_by_id(agg_id)
if not aggregation:
abort(
HTTP_STATUS_CODE_BAD_REQUEST,
Expand All @@ -386,8 +386,8 @@ def delete(self, sketch_id, group_id):
group_id: Integer primary key for an aggregation group database
model.
"""
sketch = Sketch.query.get_with_acl(sketch_id)
group = AggregationGroup.query.get(group_id)
sketch = Sketch.get_with_acl(sketch_id)
group = AggregationGroup.get_by_id(group_id)

if not group:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No Group found with this ID.")
Expand Down Expand Up @@ -442,7 +442,7 @@ def post(self, sketch_id):
"Not able to run aggregation, unable to validate form data.",
)

sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")

Expand Down Expand Up @@ -579,7 +579,7 @@ def get(self, sketch_id):
Returns:
Views in JSON (instance of flask.wrappers.Response)
"""
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")

Expand Down Expand Up @@ -656,7 +656,7 @@ def post(self, sketch_id):
if not form:
abort(HTTP_STATUS_CODE_BAD_REQUEST, "Unable to validate form data.")

sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")

Expand Down Expand Up @@ -689,7 +689,7 @@ def get(self, sketch_id):
Returns:
Views in JSON (instance of flask.wrappers.Response)
"""
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")

Expand Down Expand Up @@ -731,7 +731,7 @@ def post(self, sketch_id):
Returns:
An aggregation in JSON (instance of flask.wrappers.Response)
"""
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")

Expand Down
16 changes: 8 additions & 8 deletions timesketch/api/v1/resources/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get(self, sketch_id, timeline_id):
Returns:
An analysis in JSON (instance of flask.wrappers.Response)
"""
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")

Expand All @@ -70,7 +70,7 @@ def get(self, sketch_id, timeline_id):
HTTP_STATUS_CODE_FORBIDDEN, "User does not have read access to sketch"
)

timeline = Timeline.query.get(timeline_id)
timeline = Timeline.get_by_id(timeline_id)
if not timeline:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No timeline found with this ID.")

Expand All @@ -96,7 +96,7 @@ def get(self, sketch_id):
Returns:
A analyzer session in JSON (instance of flask.wrappers.Response)
"""
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)

if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")
Expand Down Expand Up @@ -152,7 +152,7 @@ def get(self, sketch_id, session_id):
Returns:
A analyzer session in JSON (instance of flask.wrappers.Response)
"""
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)

if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")
Expand All @@ -162,7 +162,7 @@ def get(self, sketch_id, session_id):
HTTP_STATUS_CODE_FORBIDDEN, "User does not have read access to sketch"
)

analysis_session = AnalysisSession.query.get(session_id)
analysis_session = AnalysisSession.get_by_id(session_id)

return self.to_json(analysis_session)

Expand All @@ -181,7 +181,7 @@ def get(self, sketch_id):
* description: Description of the analyzer provided in the class
* is_multi: Boolean indicating if the analyzer is a multi analyzer
"""
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")
if not sketch.has_permission(current_user, "read"):
Expand Down Expand Up @@ -216,7 +216,7 @@ def post(self, sketch_id):
Returns:
A string with the response from running the analyzer.
"""
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")

Expand Down Expand Up @@ -284,7 +284,7 @@ def post(self, sketch_id):
# TODO: Change to run on Timeline instead of Index
sessions = []
for timeline_id in timeline_ids:
timeline = Timeline.query.get(timeline_id)
timeline = Timeline.get_by_id(timeline_id)
if not timeline:
continue
if not timeline.status[0].status == "ready":
Expand Down
8 changes: 4 additions & 4 deletions timesketch/api/v1/resources/archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ def get(self, sketch_id):
A sketch in JSON (instance of flask.wrappers.Response)
"""
if current_user.admin:
sketch = Sketch.query.get(sketch_id)
sketch = Sketch.get_by_id(sketch_id)
else:
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)

if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")
Expand Down Expand Up @@ -121,9 +121,9 @@ def post(self, sketch_id):
A sketch in JSON (instance of flask.wrappers.Response)
"""
if current_user.admin:
sketch = Sketch.query.get(sketch_id)
sketch = Sketch.get_by_id(sketch_id)
else:
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)

if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")
Expand Down
7 changes: 4 additions & 3 deletions timesketch/api/v1/resources/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def get(self, sketch_id):
Returns:
An analysis in JSON (instance of flask.wrappers.Response)
"""
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")

Expand All @@ -87,7 +87,7 @@ def post(self, sketch_id):
Returns:
A HTTP 200 if the attribute is successfully added or modified.
"""
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")

Expand Down Expand Up @@ -193,7 +193,7 @@ def delete(self, sketch_id):
Returns:
A HTTP response code.
"""
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")

Expand Down Expand Up @@ -231,6 +231,7 @@ def delete(self, sketch_id):
for value in attribute.values:
attribute.values.remove(value)
sketch.attributes.remove(attribute)
db_session.add(sketch)
db_session.commit()

return HTTP_STATUS_CODE_OK
Expand Down
2 changes: 1 addition & 1 deletion timesketch/api/v1/resources/datafinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def post(self, sketch_id):
Returns:
A list of JSON representations of the data sources.
"""
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")

Expand Down
14 changes: 7 additions & 7 deletions timesketch/api/v1/resources/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get(self, sketch_id):
Returns:
A list of JSON representations of the data sources.
"""
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")

Expand Down Expand Up @@ -85,7 +85,7 @@ def post(self, sketch_id):
Returns:
A datasource in JSON (instance of flask.wrappers.Response)
"""
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")

Expand All @@ -112,7 +112,7 @@ def post(self, sketch_id):
"Unable to create a data source without a timeline " "identifier.",
)

timeline = Timeline.query.get(timeline_id)
timeline = Timeline.get_by_id(timeline_id)
if not timeline:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No timeline found with this ID.")

Expand Down Expand Up @@ -150,7 +150,7 @@ def _verify_sketch_and_datasource(self, sketch_id, datasource_id):
This function aborts if the ACLs on the sketch are not sufficient and
the data source does not belong to the sketch in question.
"""
sketch = Sketch.query.get_with_acl(sketch_id)
sketch = Sketch.get_with_acl(sketch_id)
if not sketch:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID.")

Expand All @@ -160,7 +160,7 @@ def _verify_sketch_and_datasource(self, sketch_id, datasource_id):
"Unable to fetch data sources from an archived sketch.",
)

data_source = DataSource.query.get(datasource_id)
data_source = DataSource.get_by_id(datasource_id)
if not data_source:
abort(HTTP_STATUS_CODE_NOT_FOUND, "No DataSource found with this ID.")

Expand All @@ -182,7 +182,7 @@ def get(self, sketch_id, datasource_id):
A JSON representation of the data source.
"""
self._verify_sketch_and_datasource(sketch_id, datasource_id)
data_source = DataSource.query.get(datasource_id)
data_source = DataSource.get_by_id(datasource_id)
return self.to_json(data_source)

@login_required
Expand All @@ -198,7 +198,7 @@ def post(self, sketch_id, datasource_id):
"""
self._verify_sketch_and_datasource(sketch_id, datasource_id)

data_source = DataSource.query.get(datasource_id)
data_source = DataSource.get_by_id(datasource_id)
changed = False

form = request.json
Expand Down
Loading

0 comments on commit 07b6b2e

Please sign in to comment.