Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support Vyper 0.3.10 #97

Merged
merged 16 commits into from
Oct 26, 2023
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,23 @@ vyper:
Import the voting contract types like this:

```python
# @version 0.3.9
# @version 0.3.10

import voting.ballot as ballot
```

### Pragmas

Ape-Vyper supports Vyper 0.3.10's [new pragma formats](https://github.com/vyperlang/vyper/pull/3493)

#### Version Pragma

```python
#pragma version 0.3.10
```

#### Optimization Pragma

```python
#pragma optimize codesize
```
213 changes: 126 additions & 87 deletions ape_vyper/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def _install_vyper(version: Version):
) from err


def get_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]:
def get_version_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]:
"""
Extracts pragma information from Vyper source code.
Extracts version pragma information from Vyper source code.

Args:
source (str): Vyper source code
Expand All @@ -81,7 +81,12 @@ def get_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]:
source_str = source if isinstance(source, str) else source.read_text()
pragma_match = next(re.finditer(r"(?:\n|^)\s*#\s*@version\s*([^\n]*)", source_str), None)
if pragma_match is None:
return None # Try compiling with latest
# support new pragma syntax
pragma_match = next(
re.finditer(r"(?:\n|^)\s*#pragma\s+version\s*([^\n]*)", source_str), None
z80dev marked this conversation as resolved.
Show resolved Hide resolved
)
if pragma_match is None:
return None # Try compiling with latest

raw_pragma = pragma_match.groups()[0]
pragma_str = " ".join(raw_pragma.split()).replace("^", "~=")
Expand All @@ -95,6 +100,22 @@ def get_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]:
return None


def get_optimization_pragma(source: Union[str, Path]) -> Union[str, bool]:
z80dev marked this conversation as resolved.
Show resolved Hide resolved
"""
Extracts optimization pragma information from Vyper source code.

Args:
source (str): Vyper source code
z80dev marked this conversation as resolved.
Show resolved Hide resolved
Returns:
z80dev marked this conversation as resolved.
Show resolved Hide resolved
``str``, or True if no valid pragma is found (for backwards compatibility).
"""
source_str = source if isinstance(source, str) else source.read_text()
pragma_match = next(re.finditer(r"(?:\n|^)\s*#pragma\s+optimize\s+([^\n]*)", source_str), None)
if pragma_match is None:
return True
return pragma_match.groups()[0]


class VyperCompiler(CompilerAPI):
@property
def config(self) -> VyperConfig:
Expand Down Expand Up @@ -145,7 +166,7 @@ def get_imports(
def get_versions(self, all_paths: List[Path]) -> Set[str]:
versions = set()
for path in all_paths:
if version_spec := get_pragma_spec(path):
if version_spec := get_version_pragma_spec(path):
try:
# Make sure we have the best compiler available to compile this
version_iter = version_spec.filter(self.available_versions)
Expand Down Expand Up @@ -281,90 +302,112 @@ def compile(
all_settings = self.get_compiler_settings(sources, base_path=base_path)

for vyper_version, source_paths in version_map.items():
settings = all_settings.get(vyper_version, {})
path_args = {str(get_relative_path(p.absolute(), base_path)): p for p in source_paths}
input_json = {
"language": "Vyper",
"settings": settings,
"sources": {s: {"content": p.read_text()} for s, p in path_args.items()},
}
if interfaces := self.import_remapping:
input_json["interfaces"] = interfaces

vyper_binary = compiler_data[vyper_version]["vyper_binary"]
try:
result = vvm.compile_standard(
input_json,
base_path=base_path,
vyper_version=vyper_version,
vyper_binary=vyper_binary,
)
except VyperError as err:
raise VyperCompileError(err) from err

def classify_ast(_node: ASTNode):
if _node.ast_type in _FUNCTION_AST_TYPES:
_node.classification = ASTClassification.FUNCTION
version_settings = all_settings.get(vyper_version, {})
optimizations_map = self.get_optimization_pragma_map(list(source_paths))

for optimization, source_paths in optimizations_map.items():
settings: Dict[str, Any] = version_settings.copy()
settings["optimize"] = optimization
path_args = {
str(get_relative_path(p.absolute(), base_path)): p for p in source_paths
}
settings["outputSelection"] = {s: ["*"] for s in path_args}
input_json = {
"language": "Vyper",
"settings": settings,
"sources": {s: {"content": p.read_text()} for s, p in path_args.items()},
}

for child in _node.children:
classify_ast(child)
if interfaces := self.import_remapping:
input_json["interfaces"] = interfaces

for source_id, output_items in result["contracts"].items():
content = {
i + 1: ln
for i, ln in enumerate((base_path / source_id).read_text().splitlines())
}
for name, output in output_items.items():
# De-compress source map to get PC POS map.
ast = ASTNode.parse_obj(result["sources"][source_id]["ast"])
classify_ast(ast)

# Track function offsets.
function_offsets = []
for node in ast.children:
lineno = node.lineno

# NOTE: Constructor is handled elsewhere.
if node.ast_type == "FunctionDef" and "__init__" not in content.get(
lineno, ""
):
function_offsets.append((node.lineno, node.end_lineno))

evm = output["evm"]
bytecode = evm["deployedBytecode"]
opcodes = bytecode["opcodes"].split(" ")
compressed_src_map = SourceMap(__root__=bytecode["sourceMap"])
src_map = list(compressed_src_map.parse())[1:]

pcmap = (
_get_legacy_pcmap(ast, src_map, opcodes)
if vyper_version <= Version("0.3.7")
else _get_pcmap(bytecode)
vyper_binary = compiler_data[vyper_version]["vyper_binary"]
try:
result = vvm.compile_standard(
input_json,
base_path=base_path,
vyper_version=vyper_version,
vyper_binary=vyper_binary,
)
except VyperError as err:
raise VyperCompileError(err) from err

def classify_ast(_node: ASTNode):
z80dev marked this conversation as resolved.
Show resolved Hide resolved
if _node.ast_type in _FUNCTION_AST_TYPES:
_node.classification = ASTClassification.FUNCTION

for child in _node.children:
classify_ast(child)

for source_id, output_items in result["contracts"].items():
content = {
i + 1: ln
for i, ln in enumerate((base_path / source_id).read_text().splitlines())
}
for name, output in output_items.items():
# De-compress source map to get PC POS map.
ast = ASTNode.parse_obj(result["sources"][source_id]["ast"])
classify_ast(ast)

# Track function offsets.
function_offsets = []
for node in ast.children:
lineno = node.lineno

# NOTE: Constructor is handled elsewhere.
if node.ast_type == "FunctionDef" and "__init__" not in content.get(
lineno, ""
):
function_offsets.append((node.lineno, node.end_lineno))

evm = output["evm"]
bytecode = evm["deployedBytecode"]
opcodes = bytecode["opcodes"].split(" ")
compressed_src_map = SourceMap(__root__=bytecode["sourceMap"])
src_map = list(compressed_src_map.parse())[1:]

pcmap = (
_get_legacy_pcmap(ast, src_map, opcodes)
if vyper_version <= Version("0.3.7")
else _get_pcmap(bytecode)
)

# Find content-specified dev messages.
dev_messages = {}
for line_no, line in content.items():
if match := re.search(DEV_MSG_PATTERN, line):
dev_messages[line_no] = match.group(1).strip()

contract_type = ContractType(
ast=ast,
contractName=name,
sourceId=source_id,
deploymentBytecode={"bytecode": evm["bytecode"]["object"]},
runtimeBytecode={"bytecode": bytecode["object"]},
abi=output["abi"],
sourcemap=compressed_src_map,
pcmap=pcmap,
userdoc=output["userdoc"],
devdoc=output["devdoc"],
dev_messages=dev_messages,
)
contract_types.append(contract_type)
# Find content-specified dev messages.
dev_messages = {}
for line_no, line in content.items():
if match := re.search(DEV_MSG_PATTERN, line):
dev_messages[line_no] = match.group(1).strip()

contract_type = ContractType(
ast=ast,
contractName=name,
sourceId=source_id,
deploymentBytecode={"bytecode": evm["bytecode"]["object"]},
runtimeBytecode={"bytecode": bytecode["object"]},
abi=output["abi"],
sourcemap=compressed_src_map,
pcmap=pcmap,
userdoc=output["userdoc"],
devdoc=output["devdoc"],
dev_messages=dev_messages,
)
contract_types.append(contract_type)

return contract_types

def get_optimization_pragma_map(
self, contract_filepaths: List[Path], base_path: Optional[Path] = None
) -> Dict[Union[str, bool], Set[Path]]:
base_path = base_path or self.config_manager.contracts_folder
optimization_pragma_map: Dict[Union[str, bool], Set[Path]] = {}
for path in contract_filepaths:
if pragma := get_optimization_pragma(path):
if pragma not in optimization_pragma_map:
optimization_pragma_map[pragma] = set()
optimization_pragma_map[pragma].add(path)

return optimization_pragma_map

def get_version_map(
self, contract_filepaths: List[Path], base_path: Optional[Path] = None
) -> Dict[Version, Set[Path]]:
Expand All @@ -374,7 +417,7 @@ def get_version_map(

# Sort contract_filepaths to promote consistent, reproduce-able behavior
for path in sorted(contract_filepaths):
if pragma := get_pragma_spec(path):
if pragma := get_version_pragma_spec(path):
_safe_append(source_path_by_pragma_spec, pragma, path)
else:
source_paths_without_pragma.add(path)
Expand Down Expand Up @@ -441,10 +484,6 @@ def get_compiler_settings(
continue

version_settings: Dict = {"optimize": True}
path_args = {
str(get_relative_path(p.absolute(), contracts_path)): p for p in source_paths
}
version_settings["outputSelection"] = {s: ["*"] for s in path_args}
if evm_version := data.get("evm_version"):
version_settings["evmVersion"] = evm_version

Expand Down Expand Up @@ -955,7 +994,7 @@ def _get_pcmap(bytecode: Dict) -> PCMap:
error_str = RuntimeErrorType.FALLBACK_NOT_DEFINED.value
use_loc = False
elif "bad calldatasize or callvalue" in error_type:
# Only on >=0.3.10rc3.
# Only on >=0.3.10.
# NOTE: We are no longer able to get Nonpayable checks errors since they
# are now combined.
error_str = RuntimeErrorType.INVALID_CALLDATA_OR_VALUE.value
Expand Down
2 changes: 1 addition & 1 deletion ape_vyper/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, **kwargs):

class InvalidCalldataOrValueError(VyperRuntimeError):
"""
Raises on Vyper versions >= 0.3.10rc3 in place of NonPayableError.
Raises on Vyper versions >= 0.3.10 in place of NonPayableError.
"""

def __init__(self, **kwargs):
Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@
"0.3.4",
"0.3.7",
"0.3.9",
"0.3.10rc3",
"0.3.10",
)

CONTRACT_VERSION_GEN_MAP = {
"": (
"0.3.7",
"0.3.9",
"0.3.10rc3",
"0.3.10",
),
"sub_reverts": ALL_VERSIONS,
}
Expand Down Expand Up @@ -188,7 +188,7 @@ def account():
return ape.accounts.test_accounts[0]


@pytest.fixture(params=("037", "039", "0310rc3"))
@pytest.fixture(params=("037", "039", "0310"))
def traceback_contract(request, account, project, geth_provider):
return _get_tb_contract(request.param, project, account)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @version 0.3.9
# @version 0.3.10

# Test dev messages in various code placements
@external
Expand Down
8 changes: 8 additions & 0 deletions tests/contracts/passing_contracts/optimize_codesize.vy
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#pragma version 0.3.10
z80dev marked this conversation as resolved.
Show resolved Hide resolved
#pragma optimize codesize

x: uint256

@external
def __init__():
self.x = 0
2 changes: 1 addition & 1 deletion tests/test_ape_reverts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def older_reverts_contract(account, project, geth_provider, request):
return container.deploy(sender=account)


@pytest.fixture(params=("037", "039", "0310rc3"))
@pytest.fixture(params=("037", "039", "0310"))
def reverts_contract_instance(account, project, geth_provider, request):
sub_reverts_container = project.get_contract(f"sub_reverts_{request.param}")
sub_reverts = sub_reverts_container.deploy(sender=account)
Expand Down
Loading
Loading