Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Audit review #32

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions smart-contracts/contracts/diamond/facets/Marketplace.sol
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,11 @@ contract Marketplace is
emit MaretplaceFeeUpdated(bidFee_);
}

function setMinMaxBidPricePerSecond(
uint256 bidMinPricePerSecond_,
uint256 bidMaxPricePerSecond_
) public onlyOwner {
function setMinMaxBidPricePerSecond(uint256 bidMinPricePerSecond_, uint256 bidMaxPricePerSecond_) public onlyOwner {
if (bidMinPricePerSecond_ == 0) {
revert MarketplaceBidMinPricePerSecondIsZero();
}

if (bidMinPricePerSecond_ > bidMaxPricePerSecond_) {
revert MarketplaceBidMinPricePerSecondIsInvalid();
}
Expand All @@ -74,7 +71,9 @@ contract Marketplace is
BidsStorage storage bidsStorage = _getBidsStorage();
MarketStorage storage marketStorage = _getMarketStorage();

if (pricePerSecond_ < marketStorage.bidMinPricePerSecond || pricePerSecond_ > marketStorage.bidMaxPricePerSecond) {
if (
pricePerSecond_ < marketStorage.bidMinPricePerSecond || pricePerSecond_ > marketStorage.bidMaxPricePerSecond
) {
revert MarketplaceBidPricePerSecondInvalid();
}

Expand Down
14 changes: 10 additions & 4 deletions smart-contracts/contracts/diamond/facets/ModelRegistry.sol
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@ contract ModelRegistry is IModelRegistry, OwnableDiamondStorage, ModelStorage, B
}

function modelRegister(
bytes32 modelId_,
bytes32 baseModelId_,
bytes32 ipfsCID_,
uint256 fee_,
uint256 amount_,
string calldata name_,
string[] memory tags_
) external {
ModelsStorage storage modelsStorage = _getModelsStorage();

bytes32 modelId_ = getModelId(_msgSender(), baseModelId_);
Model storage model = modelsStorage.models[modelId_];

uint256 newStake_ = model.stake + amount_;
Expand All @@ -51,8 +53,6 @@ contract ModelRegistry is IModelRegistry, OwnableDiamondStorage, ModelStorage, B

model.createdAt = uint128(block.timestamp);
model.owner = _msgSender();
} else {
_onlyAccount(model.owner);
}

model.stake = newStake_;
Expand All @@ -67,8 +67,10 @@ contract ModelRegistry is IModelRegistry, OwnableDiamondStorage, ModelStorage, B
emit ModelRegisteredUpdated(_msgSender(), modelId_);
}

function modelDeregister(bytes32 modelId_) external {
function modelDeregister(bytes32 baseModelId_) external {
ModelsStorage storage modelsStorage = _getModelsStorage();

bytes32 modelId_ = getModelId(_msgSender(), baseModelId_);
Model storage model = modelsStorage.models[modelId_];

_onlyAccount(model.owner);
Expand All @@ -91,4 +93,8 @@ contract ModelRegistry is IModelRegistry, OwnableDiamondStorage, ModelStorage, B

emit ModelDeregistered(model.owner, modelId_);
}

function getModelId(address account_, bytes32 baseModelId_) public pure returns (bytes32) {
return keccak256(abi.encodePacked(account_, baseModelId_));
}
}
43 changes: 29 additions & 14 deletions smart-contracts/contracts/diamond/facets/SessionRouter.sol
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ import {LibSD} from "../../libs/LibSD.sol";

import {ISessionRouter} from "../../interfaces/facets/ISessionRouter.sol";

import "hardhat/console.sol";

contract SessionRouter is
ISessionRouter,
OwnableDiamondStorage,
Expand Down Expand Up @@ -258,36 +256,46 @@ contract SessionRouter is
return (sessionEnd_ - session.openedAt) * bid.pricePerSecond - withdrawnAmount;
}

function _getProviderOnHoldAmount(Session storage session, Bid storage bid) private view returns (uint256) {
uint128 startOfClosedAt = startOfTheDay(session.closedAt);
if (block.timestamp >= startOfClosedAt + 1 days) {
return 0;
}

// `closedAt` - latest timestamp, cause `endsAt` bigger then `closedAt`
// Lock the provider's tokens for the current day.
// Withdrawal is allowed after a day after `startOfTheDay(session.closedAt)`.
return (session.closedAt - startOfClosedAt.max(session.openedAt)) * bid.pricePerSecond;
}

function _rewardProviderAfterClose(bool noDispute_, Session storage session, Bid storage bid) internal {
uint128 startOfToday_ = startOfTheDay(uint128(block.timestamp));
bool isClosingLate_ = uint128(block.timestamp) > session.endsAt;
bool isClosingLate_ = session.closedAt >= session.endsAt;

uint256 providerAmountToWithdraw_ = _getProviderRewards(session, bid, true);
uint256 providerOnHoldAmount = 0;
// Enter when the user has a dispute AND closing early
if (!noDispute_ && !isClosingLate_) {
providerOnHoldAmount =
(session.endsAt.min(session.closedAt) - startOfToday_.max(session.openedAt)) *
bid.pricePerSecond;
providerOnHoldAmount = _getProviderOnHoldAmount(session, bid);
}
providerAmountToWithdraw_ -= providerOnHoldAmount;

_claimForProvider(session, providerAmountToWithdraw_);
}

function _rewardUserAfterClose(Session storage session, Bid storage bid) private {
uint128 startOfToday_ = startOfTheDay(uint128(block.timestamp));
bool isClosingLate_ = uint128(block.timestamp) > session.endsAt;
uint128 startOfClosedAt_ = startOfTheDay(session.closedAt);
bool isClosingLate_ = session.closedAt >= session.endsAt;

uint256 userStakeToProvider = session.isDirectPaymentFromUser ? _getProviderRewards(session, bid, false) : 0;
uint256 userStake = session.stake - userStakeToProvider;
uint256 userStakeToLock_ = 0;
if (!isClosingLate_) {
uint256 userDuration_ = session.endsAt.min(session.closedAt) - session.openedAt.max(startOfToday_);
uint256 userDuration_ = session.endsAt.min(session.closedAt) - session.openedAt.max(startOfClosedAt_);
uint256 userInitialLock_ = userDuration_ * bid.pricePerSecond;
userStakeToLock_ = userStake.min(stipendToStake(userInitialLock_, startOfToday_));
userStakeToLock_ = userStake.min(stipendToStake(userInitialLock_, startOfClosedAt_));

_getSessionsStorage().userStakesOnHold[session.user].push(
OnHold(userStakeToLock_, uint128(startOfToday_ + 1 days))
OnHold(userStakeToLock_, uint128(startOfClosedAt_ + 1 days))
);
}
uint256 userAmountToWithdraw_ = userStake - userStakeToLock_;
Expand Down Expand Up @@ -339,7 +347,10 @@ contract SessionRouter is
Bid storage bid = _getBidsStorage().bids[session.bidId];

_onlyAccount(bid.provider);
_claimForProvider(session, _getProviderRewards(session, bid, true));

uint256 amount_ = _getProviderRewards(session, bid, true) - _getProviderOnHoldAmount(session, bid);

_claimForProvider(session, amount_);
}

/**
Expand Down Expand Up @@ -370,7 +381,11 @@ contract SessionRouter is
if (session.isDirectPaymentFromUser) {
IERC20(_getBidsStorage().token).safeTransfer(bid.provider, amount_);
} else {
IERC20(_getBidsStorage().token).safeTransferFrom(_getSessionsStorage().fundingAccount, bid.provider, amount_);
IERC20(_getBidsStorage().token).safeTransferFrom(
_getSessionsStorage().fundingAccount,
bid.provider,
amount_
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ contract ModelStorage is IModelStorage {
}

function getIsModelActive(bytes32 modelId_) public view returns (bool) {
return !_getModelsStorage().models[modelId_].isDeleted;
return (!_getModelsStorage().models[modelId_].isDeleted && _getModelsStorage().models[modelId_].createdAt != 0);
}

/** INTERNAL */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ contract ProviderStorage is IProviderStorage {
}

function getIsProviderActive(address provider_) public view returns (bool) {
return !_getProvidersStorage().providers[provider_].isDeleted;
return (!_getProvidersStorage().providers[provider_].isDeleted &&
_getProvidersStorage().providers[provider_].createdAt != 0);
}

/** INTERNAL */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ contract StatsStorage is IStatsStorage {
mapping(bytes32 => ModelStats) modelStats;
}

bytes32 public constant STATS_STORAGE_SLOT = keccak256("diamond.stats.storage");
bytes32 public constant STATS_STORAGE_SLOT = keccak256("diamond.standard.stats.storage");

/** PUBLIC, GETTERS */
function getProviderModelStats(
Expand Down
5 changes: 1 addition & 4 deletions smart-contracts/contracts/interfaces/facets/IMarketplace.sol
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ interface IMarketplace is IMarketplaceStorage {
* @param bidMinPricePerSecond_ Min price per second for bid
* @param bidMaxPricePerSecond_ Max price per second for bid
*/
function setMinMaxBidPricePerSecond(
uint256 bidMinPricePerSecond_,
uint256 bidMaxPricePerSecond_
) external;
function setMinMaxBidPricePerSecond(uint256 bidMinPricePerSecond_, uint256 bidMaxPricePerSecond_) external;

/**
* The function to create the bid.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,11 @@ interface IModelRegistry is IModelStorage {
* @param modelId_ The model ID.
*/
function modelDeregister(bytes32 modelId_) external;

/**
* Form model ID for the user models.
* @param account_ The address.
* @param baseModelId_ The base model ID.
*/
function getModelId(address account_, bytes32 baseModelId_) external pure returns (bytes32);
}
33 changes: 27 additions & 6 deletions smart-contracts/test/diamond/facets/Marketplace.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ describe('Marketplace', () => {

let token: MorpheusToken;

const modelId1 = getHex(Buffer.from('1'));
const modelId2 = getHex(Buffer.from('2'));
const baseModelId1 = getHex(Buffer.from('1'));
const baseModelId2 = getHex(Buffer.from('2'));
let modelId1 = getHex(Buffer.from(''));
let modelId2 = getHex(Buffer.from(''));

before(async () => {
[OWNER, SECOND, PROVIDER] = await ethers.getSigners();
Expand All @@ -54,8 +56,11 @@ describe('Marketplace', () => {

const ipfsCID = getHex(Buffer.from('ipfs://ipfsaddress'));
await providerRegistry.connect(SECOND).providerRegister(wei(100), 'test');
await modelRegistry.connect(SECOND).modelRegister(modelId1, ipfsCID, 0, wei(100), 'name', ['tag_1']);
await modelRegistry.connect(SECOND).modelRegister(modelId2, ipfsCID, 0, wei(100), 'name', ['tag_1']);
await modelRegistry.connect(SECOND).modelRegister(baseModelId1, ipfsCID, 0, wei(100), 'name', ['tag_1']);
await modelRegistry.connect(SECOND).modelRegister(baseModelId2, ipfsCID, 0, wei(100), 'name', ['tag_1']);

modelId1 = await modelRegistry.getModelId(SECOND, baseModelId1);
modelId2 = await modelRegistry.getModelId(SECOND, baseModelId2);

await reverter.snapshot();
});
Expand Down Expand Up @@ -217,18 +222,34 @@ describe('Marketplace', () => {
);
});
it('should throw error when the model is deregistered', async () => {
await modelRegistry.connect(SECOND).modelDeregister(modelId1);
await modelRegistry.connect(SECOND).modelDeregister(baseModelId1);
await expect(marketplace.connect(SECOND).postModelBid(modelId1, wei(10))).to.be.revertedWithCustomError(
marketplace,
'MarketplaceModelNotFound',
);
});
it('should throw error when the bid price is invalid', async () => {
it('should throw error when the bid price is invalid #1', async () => {
await expect(marketplace.connect(SECOND).postModelBid(modelId1, wei(99999))).to.be.revertedWithCustomError(
marketplace,
'MarketplaceBidPricePerSecondInvalid',
);
});
it('should throw error when the bid price is invalid #2', async () => {
await expect(marketplace.connect(SECOND).postModelBid(modelId1, wei(0))).to.be.revertedWithCustomError(
marketplace,
'MarketplaceBidPricePerSecondInvalid',
);
});
it('should throw error when model not found', async () => {
await expect(
marketplace.connect(SECOND).postModelBid(getHex(Buffer.from('123')), wei(1)),
).to.be.revertedWithCustomError(marketplace, 'MarketplaceModelNotFound');
});
it('should throw error when provider not found', async () => {
await expect(
marketplace.connect(OWNER).postModelBid(getHex(Buffer.from('123')), wei(1)),
).to.be.revertedWithCustomError(marketplace, 'MarketplaceProviderNotFound');
});
});

describe('#deleteModelBid', async () => {
Expand Down
Loading
Loading