Skip to content

Commit

Permalink
refactored
Browse files Browse the repository at this point in the history
  • Loading branch information
RuslanProgrammer committed Sep 18, 2024
1 parent ba20c15 commit e4da1c4
Show file tree
Hide file tree
Showing 18 changed files with 114 additions and 104 deletions.
4 changes: 2 additions & 2 deletions smart-contracts/contracts/diamond/facets/Marketplace.sol
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ contract Marketplace is

bytes32 bidId_ = getBidId(provider_, modelId_, nonce_);

addBid(bidId_, Bid(provider_, modelId_, pricePerSecond_, nonce_, uint128(block.timestamp), 0));
setBid(bidId_, Bid(provider_, modelId_, pricePerSecond_, nonce_, uint128(block.timestamp), 0));

addProviderBid(provider_, bidId_);
addModelBid(modelId_, bidId_);
Expand All @@ -108,9 +108,9 @@ contract Marketplace is
return bidId_;
}

/// @dev passing bidId and bid storage to avoid double storage access
function _deleteBid(bytes32 bidId_) private {
Bid storage bid = getBid(bidId_);

bid.deletedAt = uint128(block.timestamp);

removeProviderActiveBids(bid.provider, bidId_);
Expand Down
26 changes: 14 additions & 12 deletions smart-contracts/contracts/diamond/facets/ModelRegistry.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ contract ModelRegistry is IModelRegistry, OwnableDiamondStorage, ModelStorage, B

function __ModelRegistry_init() external initializer(MODEL_STORAGE_SLOT) {}

/// @notice Sets the minimum stake required for a model
function modelSetMinStake(uint256 modelMinimumStake_) external onlyOwner {
setModelMinimumStake(modelMinimumStake_);
emit ModelMinStakeUpdated(modelMinimumStake_);
function setModelMinimumStake(uint256 modelMinimumStake_) external onlyOwner {
_setModelMinimumStake(modelMinimumStake_);
emit ModelMinimumStakeSet(modelMinimumStake_);
}

/// @notice Registers or updates existing model
function modelRegister(
// TODO: it is not secure (frontrunning) to take the modelId as key
bytes32 modelId_,
bytes32 ipfsCID_,
uint256 fee_,
Expand All @@ -31,17 +31,22 @@ contract ModelRegistry is IModelRegistry, OwnableDiamondStorage, ModelStorage, B
string calldata name_,
string[] calldata tags_
) external {
if (!_ownerOrModelOwner(owner_)) {
if (!_isOwnerOrModelOwner(owner_)) {
// TODO: such that we cannon create a model with the owner as another address
// Do we need this check?
revert NotOwnerOrModelOwner();
}

Model memory model_ = models(modelId_);
// TODO: there is no way to decrease the stake
uint256 newStake_ = model_.stake + addStake_;
if (newStake_ < modelMinimumStake()) {
revert StakeTooLow();
}

getToken().safeTransferFrom(_msgSender(), address(this), addStake_);
if (addStake_ > 0) {
getToken().safeTransferFrom(_msgSender(), address(this), addStake_);
}

uint128 createdAt_ = model_.createdAt;
if (createdAt_ == 0) {
Expand All @@ -50,7 +55,7 @@ contract ModelRegistry is IModelRegistry, OwnableDiamondStorage, ModelStorage, B
setModelActive(modelId_, true);
createdAt_ = uint128(block.timestamp);
} else {
if (!_ownerOrModelOwner(model_.owner)) {
if (!_isOwnerOrModelOwner(model_.owner)) {
revert NotOwnerOrModelOwner();
}
if (model_.isDeleted) {
Expand All @@ -63,18 +68,15 @@ contract ModelRegistry is IModelRegistry, OwnableDiamondStorage, ModelStorage, B
emit ModelRegisteredUpdated(owner_, modelId_);
}

/// @notice Deregisters a model
function modelDeregister(bytes32 modelId_) external {
Model storage model = models(modelId_);

if (!isModelExists(modelId_)) {
revert ModelNotFound();
}

if (!_ownerOrModelOwner(model.owner)) {
if (!_isOwnerOrModelOwner(model.owner)) {
revert NotOwnerOrModelOwner();
}

if (!isModelActiveBidsEmpty(modelId_)) {
revert ModelHasActiveBids();
}
Expand All @@ -95,7 +97,7 @@ contract ModelRegistry is IModelRegistry, OwnableDiamondStorage, ModelStorage, B
return models(modelId_).createdAt != 0;
}

function _ownerOrModelOwner(address modelOwner_) internal view returns (bool) {
function _isOwnerOrModelOwner(address modelOwner_) internal view returns (bool) {
return _msgSender() == owner() || _msgSender() == modelOwner_;
}
}
12 changes: 9 additions & 3 deletions smart-contracts/contracts/diamond/facets/ProviderRegistry.sol
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ contract ProviderRegistry is IProviderRegistry, OwnableDiamondStorage, ProviderS
/// @param endpoint_ provider endpoint (host.com:1234)
function providerRegister(address providerAddress_, uint256 amount_, string calldata endpoint_) external {
if (!_ownerOrProvider(providerAddress_)) {
// TODO: such that we cannon create a provider with the owner as another address
// Do we need this check?
revert NotOwnerOrProvider();
}

Expand All @@ -36,7 +38,9 @@ contract ProviderRegistry is IProviderRegistry, OwnableDiamondStorage, ProviderS
revert StakeTooLow();
}

getToken().safeTransferFrom(_msgSender(), address(this), amount_);
if (amount_ > 0) {
getToken().safeTransferFrom(_msgSender(), address(this), amount_);
}

// if we add stake to an existing provider the limiter period is not reset
uint128 createdAt_ = provider_.createdAt;
Expand Down Expand Up @@ -73,11 +77,13 @@ contract ProviderRegistry is IProviderRegistry, OwnableDiamondStorage, ProviderS

Provider storage provider = providers(provider_);
uint256 withdrawable_ = _getWithdrawableStake(provider);

provider.stake -= withdrawable_;
provider.isDeleted = true;

getToken().safeTransfer(_msgSender(), withdrawable_);
if (withdrawable_ > 0) {
getToken().safeTransfer(_msgSender(), withdrawable_);
}

emit ProviderDeregistered(provider_);
}
Expand Down
22 changes: 12 additions & 10 deletions smart-contracts/contracts/diamond/facets/SessionRouter.sol
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ contract SessionRouter is
bytes calldata providerApproval_,
bytes calldata signature_
) external returns (bytes32) {
// should a user pass the bidId to compare with a providerApproval?
bytes32 bidId_ = _extractProviderApproval(providerApproval_);

Bid memory bid_ = getBid(bidId_);
Expand All @@ -64,16 +65,17 @@ contract SessionRouter is
if (!_isValidReceipt(bid_.provider, providerApproval_, signature_)) {
revert ProviderSignatureMismatch();
}
if (isApproved(providerApproval_)) {
if (isApprovalUsed(providerApproval_)) {
revert DuplicateApproval();
}
approve(providerApproval_);
setApprovalUsed(providerApproval_);

uint256 endsAt_ = whenSessionEnds(amount_, bid_.pricePerSecond, block.timestamp);
if (endsAt_ - block.timestamp < MIN_SESSION_DURATION) {
revert SessionTooShort();
}

// do we need to specify the amount in id?
bytes32 sessionId_ = getSessionId(_msgSender(), bid_.provider, amount_, incrementSessionNonce());
setSession(
sessionId_,
Expand All @@ -82,13 +84,13 @@ contract SessionRouter is
user: _msgSender(),
provider: bid_.provider,
modelId: bid_.modelId,
bidID: bidId_,
bidId: bidId_,
stake: amount_,
pricePerSecond: bid_.pricePerSecond,
closeoutReceipt: "",
closeoutType: 0,
providerWithdrawnAmount: 0,
openedAt: uint128(block.timestamp),
openedAt: block.timestamp,
endsAt: endsAt_,
closedAt: 0
})
Expand Down Expand Up @@ -132,7 +134,7 @@ contract SessionRouter is

// update session record
session.closeoutReceipt = receiptEncoded_; //TODO: remove that field in favor of tps and ttftMs
session.closedAt = uint128(block.timestamp);
session.closedAt = block.timestamp;

// calculate provider withdraw
uint256 providerWithdraw_;
Expand Down Expand Up @@ -200,19 +202,18 @@ contract SessionRouter is
// withdraw provider
_rewardProvider(session, providerWithdraw_, false);

// withdraw user
getToken().safeTransfer(session.user, userWithdraw_);
}

/// @notice allows provider to claim their funds
function claimProviderBalance(bytes32 sessionId_, uint256 amountToWithdraw_) external {
Session storage session = _getSession(sessionId_);
if (!_ownerOrProvider(session.provider)) {
revert NotOwnerOrProvider();
}
if (session.openedAt == 0) {
revert SessionNotFound();
}
if (!_ownerOrProvider(session.provider)) {
revert NotOwnerOrProvider();
}

uint256 withdrawableAmount = _getProviderClaimableBalance(session);
if (amountToWithdraw_ > withdrawableAmount) {
Expand All @@ -224,6 +225,7 @@ contract SessionRouter is

/// @notice deletes session from the history
function deleteHistory(bytes32 sessionId_) external {
// Why do we need this function?
Session storage session = _getSession(sessionId_);
if (!_ownerOrUser(session.user)) {
revert NotOwnerOrUser();
Expand Down Expand Up @@ -294,7 +296,7 @@ contract SessionRouter is
/// @dev parameters should be the same as in Ethereum L1 Distribution contract
/// @dev at address 0x47176B2Af9885dC6C4575d4eFd63895f7Aaa4790
/// @dev call 'Distribution.pools(3)' where '3' is a poolId
function setPoolConfig(uint256 index, Pool calldata pool) public onlyOwner {
function setPoolConfig(uint256 index, Pool calldata pool) external onlyOwner {
if (index >= getPools().length) {
revert PoolIndexOutOfBounds();
}
Expand Down
2 changes: 1 addition & 1 deletion smart-contracts/contracts/diamond/storages/BidStorage.sol
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ contract BidStorage is IBidStorage {
_getBidStorage().modelBids[modelId].push(bidId);
}

function addBid(bytes32 bidId, Bid memory bid) internal {
function setBid(bytes32 bidId, Bid memory bid) internal {
_getBidStorage().bids[bidId] = bid;
}

Expand Down
4 changes: 2 additions & 2 deletions smart-contracts/contracts/diamond/storages/ModelStorage.sol
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ contract ModelStorage is IModelStorage {
_getModelStorage().models[modelId] = model;
}

function setModelMinimumStake(uint256 _modelMinimumStake) internal {
_getModelStorage().modelMinimumStake = _modelMinimumStake;
function _setModelMinimumStake(uint256 modelMinimumStake_) internal {
_getModelStorage().modelMinimumStake = modelMinimumStake_;
}

function models(bytes32 id) internal view returns (Model storage) {
Expand Down
4 changes: 2 additions & 2 deletions smart-contracts/contracts/diamond/storages/SessionStorage.sol
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ contract SessionStorage is ISessionStorage {
return _getSessionStorage().sessionNonce++;
}

function isApproved(bytes memory approval) internal view returns (bool) {
function isApprovalUsed(bytes memory approval) internal view returns (bool) {
return _getSessionStorage().isApprovalUsed[approval];
}

function approve(bytes memory approval) internal {
function setApprovalUsed(bytes memory approval) internal {
_getSessionStorage().isApprovalUsed[approval] = true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {IModelStorage} from "../storage/IModelStorage.sol";
interface IModelRegistry is IModelStorage {
event ModelRegisteredUpdated(address indexed owner, bytes32 indexed modelId);
event ModelDeregistered(address indexed owner, bytes32 indexed modelId);
event ModelMinStakeUpdated(uint256 newStake);
event ModelMinimumStakeSet(uint256 newStake);

error ModelNotFound();
error StakeTooLow();
Expand All @@ -15,7 +15,7 @@ interface IModelRegistry is IModelStorage {

function __ModelRegistry_init() external;

function modelSetMinStake(uint256 modelMinimumStake_) external;
function setModelMinimumStake(uint256 modelMinimumStake_) external;

function modelRegister(
bytes32 modelId_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ interface ISessionStorage {
address user;
address provider;
bytes32 modelId;
bytes32 bidID;
bytes32 bidId;
uint256 stake;
uint256 pricePerSecond;
bytes closeoutReceipt;
uint256 closeoutType;
uint256 closeoutType; // use enum ??
// amount of funds that was already withdrawn by provider (we allow to withdraw for the previous day)
uint256 providerWithdrawnAmount;
uint256 openedAt;
Expand Down
2 changes: 1 addition & 1 deletion smart-contracts/test/diamond/Marketplace.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ describe('Marketplace', () => {
user: SECOND,
provider: bid.provider,
modelId: bid.modelId,
bidID: bid.id,
bidId: bid.id,
stake: (totalCost * totalSupply) / todaysBudget,
};

Expand Down
10 changes: 5 additions & 5 deletions smart-contracts/test/diamond/ModelRegistry.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ describe('Model registry', () => {
user: SECOND,
provider: bid.provider,
modelId: bid.modelId,
bidID: bid.id,
bidId: bid.id,
stake: (totalCost * totalSupply) / todaysBudget,
};

Expand Down Expand Up @@ -270,7 +270,7 @@ describe('Model registry', () => {

it('Should error when registering with insufficient stake', async () => {
const minStake = 100n;
await modelRegistry.modelSetMinStake(minStake);
await modelRegistry.setModelMinimumStake(minStake);

await expect(
modelRegistry.modelRegister(randomBytes32(), randomBytes32(), 0n, 0n, OWNER, 'a', []),
Expand Down Expand Up @@ -472,14 +472,14 @@ describe('Model registry', () => {
describe('Min stake', () => {
it('Should set min stake', async () => {
const minStake = 100n;
await expect(modelRegistry.modelSetMinStake(minStake))
.to.emit(modelRegistry, 'ModelMinStakeUpdated')
await expect(modelRegistry.setModelMinimumStake(minStake))
.to.emit(modelRegistry, 'ModelMinimumStakeSet')
.withArgs(minStake);

expect(await modelRegistry.modelMinimumStake()).eq(minStake);
});
it('Should error when not owner is setting min stake', async () => {
await expect(modelRegistry.connect(THIRD).modelSetMinStake(0)).to.revertedWith(
await expect(modelRegistry.connect(THIRD).setModelMinimumStake(0)).to.revertedWith(
'OwnableDiamondStorage: not an owner',
);
});
Expand Down
2 changes: 1 addition & 1 deletion smart-contracts/test/diamond/ProviderRegistry.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ describe('Provider registry', () => {
user: SECOND,
provider: bid.provider,
modelId: bid.modelId,
bidID: bid.id,
bidId: bid.id,
stake: (totalCost * totalSupply) / todaysBudget,
};

Expand Down
Loading

0 comments on commit e4da1c4

Please sign in to comment.