Skip to content

Commit

Permalink
chore (backend): upgrade to pydantic 2
Browse files Browse the repository at this point in the history
  • Loading branch information
brownben committed Dec 6, 2024
1 parent d1c25ad commit 1261f5d
Show file tree
Hide file tree
Showing 19 changed files with 182 additions and 123 deletions.
4 changes: 2 additions & 2 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies = [
"uvloop~=0.21.0; platform_system != 'Windows'",

# database
"piccolo==0.119.0",
"piccolo==1.22.0",
"asyncpg~=0.30.0",

# results file parsing
Expand All @@ -34,7 +34,7 @@ dependencies = [
]
optional-dependencies.dev = [
"aiosqlite~=0.20.0",
"coverage~=7.6.8",
"coverage~=7.6.9",
"mypy~=1.13.0",
"respx~=0.21.1",
"ruff~=0.8.2",
Expand Down
12 changes: 6 additions & 6 deletions backend/src/database/competitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def as_competitor(record: dict[str, Any] | None) -> Competitor | None:
if not record:
return None

return Competitor.parse_obj(record)
return Competitor.model_validate(record)


class Competitors:
Expand All @@ -44,14 +44,14 @@ async def get_pool_count() -> int:

@staticmethod
async def create(competitor: NewCompetitor) -> int:
new_row = CompetitorTable(**competitor.dict())
new_row = CompetitorTable(**competitor.model_dump())
await new_row.save().run()
return new_row.id

@staticmethod
async def get_all() -> Iterable[Competitor]:
return (
Competitor.parse_obj(competitor)
Competitor.model_validate(competitor)
for competitor in await CompetitorTable.select(*competitor_fields)
.order_by(CompetitorTable.id)
.run()
Expand All @@ -60,7 +60,7 @@ async def get_all() -> Iterable[Competitor]:
@staticmethod
async def get_by_pool(competitor_pool_name: str) -> Iterable[Competitor]:
return (
Competitor.parse_obj(competitor)
Competitor.model_validate(competitor)
for competitor in await CompetitorTable.select(*competitor_fields)
.where(CompetitorTable.competitor_pool == competitor_pool_name)
.order_by(CompetitorTable.id)
Expand Down Expand Up @@ -98,7 +98,7 @@ async def update_by_id(competitor_id: int, competitor: NewCompetitor) -> bool:
if not existing_competitor:
return False

for key, value in competitor.dict().items():
for key, value in competitor.model_dump().items():
setattr(existing_competitor, key, value)

await existing_competitor.save().run()
Expand All @@ -125,7 +125,7 @@ async def merge(
@staticmethod
async def search(query: str) -> Iterable[Competitor]:
return (
Competitor.parse_obj(competitor)
Competitor.model_validate(competitor)
for competitor in await CompetitorTable.select(*competitor_fields)
.where(CompetitorTable.name.ilike(f"%{query}%"))
.limit(12)
Expand Down
30 changes: 16 additions & 14 deletions backend/src/database/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def as_event(record: dict[str, Any]) -> EventWithUploadKey | None:
if not record:
return None

return EventWithUploadKey.parse_obj(record)
return EventWithUploadKey.model_validate(record)


def generate_upload_key() -> str:
Expand All @@ -72,7 +72,7 @@ async def create(event: EventCreationRequest, league: League) -> None:
date=event.date,
organiser=event.organiser,
part_of=event.part_of,
website=event.website,
website=str(event.website),
more_information=event.more_information,
results_links=event.results_links,
upload_key=generate_upload_key(),
Expand All @@ -96,7 +96,7 @@ async def create(event: EventCreationRequest, league: League) -> None:
@staticmethod
async def get_all() -> Iterable[EventWithUploadKey]:
return (
EventWithUploadKey.parse_obj(event)
EventWithUploadKey.model_validate(event)
for event in await EventTable.select(*event_fields)
.order_by(EventTable.date, ascending=False)
.output(load_json=True)
Expand Down Expand Up @@ -157,8 +157,10 @@ async def update(
if not existing_event:
return "no-event"

for key, value in event.dict().items():
if key != "id":
for key, value in event.model_dump().items():
if key == "website":
setattr(existing_event, key, str(value))
elif key != "id":
setattr(existing_event, key, value)

await existing_event.save().run()
Expand All @@ -172,7 +174,7 @@ async def delete_by_id(id: str) -> None:
@staticmethod
async def get_latest_with_results(limit: int = 12) -> Iterable[EventWithUploadKey]:
return (
EventWithUploadKey.parse_obj(event)
EventWithUploadKey.model_validate(event)
for event in await EventTable.select(*event_fields)
.where(EventTable.results_uploaded == True)
.order_by(EventTable.date, ascending=False)
Expand All @@ -184,7 +186,7 @@ async def get_latest_with_results(limit: int = 12) -> Iterable[EventWithUploadKe
@staticmethod
async def get_allowing_results_submission() -> Iterable[EventWithUploadKey]:
return (
EventWithUploadKey.parse_obj(event)
EventWithUploadKey.model_validate(event)
for event in await EventTable.select(*event_fields)
.where(EventTable.allow_user_submitted_results == True)
.output(load_json=True)
Expand All @@ -196,7 +198,7 @@ async def get_by_competitor_pool(
competitor_pool: str,
) -> Iterable[EventWithUploadKey]:
return (
EventWithUploadKey.parse_obj(event)
EventWithUploadKey.model_validate(event)
for event in await EventTable.select(*event_fields)
.where(EventTable.competitor_pool == competitor_pool)
.order_by(EventTable.date)
Expand All @@ -211,7 +213,9 @@ async def update_results_links(
await (
EventTable.update(
{
EventTable.results_links: results_links,
EventTable.results_links: {
k: str(v) for k, v in results_links.items()
},
EventTable.results_uploaded: True,
EventTable.results_uploaded_time: datetime.datetime.now(),
}
Expand All @@ -224,15 +228,13 @@ async def update_results_links(
async def get_by_league(
league: str,
) -> Iterable[EventWithLeagueDetailsAndUploadKey]:
# TODO: piccolo output doesn't have correct overload yet for `order_by`

return (
EventWithLeagueDetailsAndUploadKey(
**event["event"],
results_links=event["results_links"],
group=event["group"],
)
for event in await LeagueEventTable.select( # type: ignore
for event in await LeagueEventTable.select(
*LeagueEventTable.event.all_columns(exclude=[EventTable.results_links]),
LeagueEventTable.event.results_links.as_alias("results_links"),
LeagueEventTable.league_group.name.as_alias("group"),
Expand All @@ -246,7 +248,7 @@ async def get_by_league(
@staticmethod
async def search(query: str) -> Iterable[EventWithUploadKey]:
return (
EventWithUploadKey.parse_obj(event)
EventWithUploadKey.model_validate(event)
for event in await EventTable.select(*event_fields)
.where(EventTable.name.ilike(f"%{query}%"))
.order_by(EventTable.date, ascending=False)
Expand Down Expand Up @@ -278,7 +280,7 @@ async def create(event: LeagueEventCreationRequest) -> None:
@staticmethod
async def get_by_league_with_results(league: str) -> Iterable[LeagueEvent]:
return [
LeagueEvent.parse_obj(event)
LeagueEvent.model_validate(event)
for event in await LeagueEventTable.select(
*league_event_fields,
LeagueEventTable.event.name.as_alias("event_name"),
Expand Down
29 changes: 16 additions & 13 deletions backend/src/database/leagues.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ def as_league(record: dict[str, Any] | None) -> League | None:
if not record:
return None

return League.parse_obj(record)
return League.model_validate(record)


def as_league_class(record: dict[str, Any] | None) -> LeagueClass | None:
if not record:
return None

return LeagueClass.parse_obj(record)
return LeagueClass.model_validate(record)


class Leagues:
Expand All @@ -59,7 +59,7 @@ async def create(league: League) -> bool:
tagline=league.tagline,
year=league.year,
coordinator=league.coordinator,
website=league.website,
website=str(league.website),
more_information=league.more_information,
visible=league.visible,
scoring_method=league.scoring_method,
Expand Down Expand Up @@ -89,8 +89,11 @@ async def update(name: str, league: League) -> bool:
if not existing_league:
return False

for key, value in league.dict().items():
setattr(existing_league, key, value)
for key, value in league.model_dump().items():
if key == "website":
setattr(existing_league, key, str(value))
else:
setattr(existing_league, key, value)

await existing_league.save().run()

Expand All @@ -99,7 +102,7 @@ async def update(name: str, league: League) -> bool:
@staticmethod
async def get_all() -> Iterable[League]:
return (
League.parse_obj(league)
League.model_validate(league)
for league in await LeagueTable.select(*league_fields)
.order_by(LeagueTable.year, ascending=False)
.order_by(LeagueTable.name)
Expand All @@ -119,7 +122,7 @@ async def get_by_name(name: str) -> League | None:
@staticmethod
async def get_by_competitor_pool(competitor_pool: str) -> Iterable[League]:
return (
League.parse_obj(league)
League.model_validate(league)
for league in await LeagueTable.select(*league_fields)
.where(LeagueTable.competitor_pool == competitor_pool)
.order_by(LeagueTable.year, ascending=False)
Expand All @@ -133,7 +136,7 @@ async def delete_by_name(name: str) -> None:
@staticmethod
async def search(query: str) -> Iterable[League]:
return (
League.parse_obj(league)
League.model_validate(league)
for league in await LeagueTable.select(*league_fields)
.where(
LeagueTable.name.ilike(f"%{query}%")
Expand Down Expand Up @@ -177,7 +180,7 @@ async def get_by_name(league: str, cls: str) -> LeagueClass | None:
@staticmethod
async def get_by_league(league: str) -> Iterable[LeagueClass]:
return (
LeagueClass.parse_obj(league)
LeagueClass.model_validate(league)
for league in await LeagueClassTable.select(*league_class_fields)
.where(LeagueClassTable.league == league)
.order_by(LeagueClassTable.name)
Expand All @@ -197,7 +200,7 @@ async def update(league: str, name: str, cls: LeagueClass) -> bool:
if not existing_class:
return False

for key, value in cls.dict().items():
for key, value in cls.model_dump().items():
setattr(existing_class, key, value)

await existing_class.save().run()
Expand Down Expand Up @@ -229,7 +232,7 @@ async def create(league_group: LeagueGroup) -> None:
@staticmethod
async def get_by_league(league: str) -> list[LeagueGroup]:
return [
LeagueGroup.parse_obj(group)
LeagueGroup.model_validate(group)
for group in await LeagueGroupTable.select(*LeagueGroupTable.all_columns())
.where(LeagueGroupTable.league == league)
.order_by(LeagueGroupTable.name)
Expand All @@ -247,7 +250,7 @@ async def get_by_name(league: str, group: str) -> LeagueGroup | None:
)

if database_result:
return LeagueGroup.parse_obj(database_result)
return LeagueGroup.model_validate(database_result)
else:
return None

Expand All @@ -274,7 +277,7 @@ async def update(league: str, name: str, group: LeagueGroup) -> bool:
if not existing_group:
return False

for key, value in group.dict().items():
for key, value in group.model_dump().items():
if key != "id":
setattr(existing_group, key, value)

Expand Down
12 changes: 6 additions & 6 deletions backend/src/database/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ async def get_by_id(id: int) -> Result | None:
.run()
)

return Result.parse_obj(database_result)
return Result.model_validate(database_result)

@staticmethod
async def get_by_competitor(competitor: int) -> Iterable[ResultWithEventName]:
return (
ResultWithEventName.parse_obj(result)
ResultWithEventName.model_validate(result)
for result in await ResultTable.select(
*results_fields, ResultTable.event.name.as_alias("event_name")
)
Expand All @@ -76,7 +76,7 @@ async def get_by_competitor(competitor: int) -> Iterable[ResultWithEventName]:
@staticmethod
async def get_by_event_and_course(event: str, course: str) -> Iterable[Result]:
return (
Result.parse_obj(result)
Result.model_validate(result)
for result in await ResultTable.select(*results_fields)
.where(ResultTable.visible == True)
.where(ResultTable.event == event)
Expand All @@ -91,7 +91,7 @@ async def get_by_event_and_courses(
event: str, courses: Iterable[str]
) -> Iterable[Result]:
return (
Result.parse_obj(result)
Result.model_validate(result)
for result in await ResultTable.select(*results_fields)
.where(ResultTable.visible == True)
.where(ResultTable.event == event)
Expand All @@ -104,7 +104,7 @@ async def get_by_event_and_courses(
@staticmethod
async def get_by_event(event: str) -> Iterator[Result]:
return (
Result.parse_obj(result)
Result.model_validate(result)
for result in await ResultTable.select(*results_fields)
.where(ResultTable.visible == True)
.where(ResultTable.event == event)
Expand All @@ -116,7 +116,7 @@ async def get_by_event(event: str) -> Iterator[Result]:
@staticmethod
async def get_event_results(event: str) -> Iterable[EventResult]:
return (
EventResult.parse_obj(result)
EventResult.model_validate(result)
for result in await ResultTable.select(
ResultTable.id,
ResultTable.time,
Expand Down
2 changes: 1 addition & 1 deletion backend/src/routes/competitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async def get_overview_for_competitor(
if not competitor:
raise HTTP_404(f"Couldn't find competitor with the id `{id}`")

return CompetitorOverview(**competitor.dict(), results=results, league=league)
return CompetitorOverview(**competitor.model_dump(), results=results, league=league)


@router.post("/merge", response_model=Message)
Expand Down
2 changes: 1 addition & 1 deletion backend/src/routes/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ async def get_results_for_event(
raise HTTP_404(f"Couldn't find event with the id `{id}`")

return EventWithResults(
**event.dict(),
**event.model_dump(),
results=assign_position_based_on_time(results or []),
league=league,
)
Expand Down
4 changes: 2 additions & 2 deletions backend/src/routes/leagues.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
def competitor_to_league_result(
competitor: Competitor, number_of_events: int = 0
) -> LeagueResult:
league_result = LeagueResult.parse_obj(competitor)
league_result = LeagueResult.model_validate(competitor.model_dump())
league_result.points = [None] * number_of_events
return league_result

Expand Down Expand Up @@ -87,7 +87,7 @@ async def get_league(name: str) -> LeagueOverviewAuthenticated:
raise HTTP_404(f"Couldn't find league with name `{name}`")

return LeagueOverviewAuthenticated(
**result.dict(),
**result.model_dump(),
classes=classes,
events=events,
groups=groups,
Expand Down
2 changes: 1 addition & 1 deletion backend/src/routes/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def create_manual_result(
authentication: bool = Depends(require_authentication),
) -> Message:
result = ResultBeforeDatabase(
**request.dict(exclude={"time"}),
**request.model_dump(exclude={"time"}),
time=parse_time(request.time),
file_points=request.points,
)
Expand Down
Loading

0 comments on commit 1261f5d

Please sign in to comment.