Skip to content

Commit

Permalink
Introduce session objects
Browse files Browse the repository at this point in the history
The fishtest remote is handled with a session object,
which allows one to persist parameters between requests,
and in particular to maintain tcp connections.
  • Loading branch information
vondele committed Jul 9, 2024
1 parent 12981ff commit 815ce51
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 59 deletions.
2 changes: 1 addition & 1 deletion server/fishtest/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
according to the route/URL mapping defined in `__init__.py`.
"""

WORKER_VERSION = 241
WORKER_VERSION = 242


@exception_view_config(HTTPException)
Expand Down
74 changes: 44 additions & 30 deletions worker/games.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ class RunException(WorkerException):
pass


class Remote:
def __init__(self, host, protocol="https", port="443"):
self.origin = "{}://{}:{}".format(protocol, host, port)
self.session = requests.Session()


def is_windows_64bit():
if "PROCESSOR_ARCHITEW6432" in os.environ:
return True
Expand All @@ -70,8 +76,8 @@ def is_64bit():
CUTECHESS_KILL_TIMEOUT = 15.0
UPDATE_RETRY_TIME = 15.0

RAWCONTENT_HOST = "https://raw.githubusercontent.com"
API_HOST = "https://api.github.com"
RAWCONTENT_HOST = "raw.githubusercontent.com"
API_HOST = "api.github.com"
EXE_SUFFIX = ".exe" if IS_WINDOWS else ""


Expand Down Expand Up @@ -175,10 +181,10 @@ def cache_write(cache, name, data):
# It may be useful to introduce more refined http exception handling in the future.


def requests_get(remote, *args, **kw):
def requests_get(remote, api, *args, **kw):
# A lightweight wrapper around requests.get()
try:
result = requests.get(remote, *args, **kw)
result = remote.session.get(remote.origin + api, *args, **kw)
result.raise_for_status() # also catch return codes >= 400
except Exception as e:
print(
Expand All @@ -187,31 +193,36 @@ def requests_get(remote, *args, **kw):
sep="",
file=sys.stderr,
)
raise WorkerException("Get request to {} failed".format(remote), e=e)
raise WorkerException(
"Get request to {} failed".format(remote.origin + api), e=e
)

return result


def requests_post(remote, *args, **kw):
def requests_post(remote, api, *args, **kw):
# A lightweight wrapper around requests.post()
try:
result = requests.post(remote, *args, **kw)
result = remote.session.post(remote.origin + api, *args, **kw)
except Exception as e:
print(
"Exception in requests.post():\n",
e,
sep="",
file=sys.stderr,
)
raise WorkerException("Post request to {} failed".format(remote), e=e)
raise WorkerException(
"Post request to {} failed".format(remote.origin + api), e=e
)

return result


def send_api_post_request(api_url, payload, quiet=False):
def send_api_post_request(remote, api, payload, quiet=False):
t0 = datetime.now(timezone.utc)
response = requests_post(
api_url,
remote,
api,
data=json.dumps(payload),
headers={"Content-Type": "application/json"},
timeout=HTTP_TIMEOUT,
Expand All @@ -226,7 +237,7 @@ def send_api_post_request(api_url, payload, quiet=False):
if not valid_response:
message = (
"The reply to post request {} was not a json encoded dictionary".format(
api_url
remote.origin + api
)
)
print(
Expand All @@ -246,15 +257,15 @@ def send_api_post_request(api_url, payload, quiet=False):
"{:6.2f} ms (s) {:7.2f} ms (w) {}".format(
s,
w,
api_url,
remote.origin + api,
)
)
if not quiet:
if "info" in response:
print("Info from remote: {}".format(response["info"]))
print(
"Post request {} handled in {:.2f}ms (server: {:.2f}ms)".format(
api_url, w, s
remote.origin + api, w, s
)
)
return response
Expand All @@ -263,7 +274,7 @@ def send_api_post_request(api_url, payload, quiet=False):
def github_api(repo):
"""Convert from https://github.com/<user>/<repo>
To https://api.github.com/repos/<user>/<repo>"""
return repo.replace("https://github.com", "https://api.github.com/repos")
return repo.replace("https://github.com", "/repos")


def required_nets(engine):
Expand Down Expand Up @@ -325,9 +336,11 @@ def download_net(remote, testing_dir, net, global_cache):
content = cache_read(global_cache, net)

if content is None:
url = remote + "/api/nn/" + net
api = "/api/nn/" + net
print("Downloading {}".format(net))
content = requests_get(url, allow_redirects=True, timeout=HTTP_TIMEOUT).content
content = requests_get(
remote, api, allow_redirects=True, timeout=HTTP_TIMEOUT
).content
cache_write(global_cache, net, content)
else:
print("Using {} from global cache".format(net))
Expand Down Expand Up @@ -461,20 +474,20 @@ def verify_signature(engine, signature, active_cores):
def download_from_github_raw(
item, owner="official-stockfish", repo="books", branch="master"
):
item_url = "{}/{}/{}/{}/{}".format(RAWCONTENT_HOST, owner, repo, branch, item)
print("Downloading {}".format(item_url))
return requests_get(item_url, timeout=HTTP_TIMEOUT).content
item_api = "/{}/{}/{}/{}".format(owner, repo, branch, item)
remote = Remote(RAWCONTENT_HOST)
print("Downloading {}".format(remote.origin + item_api))
return requests_get(remote, item_api, timeout=HTTP_TIMEOUT).content


def download_from_github_api(
item, owner="official-stockfish", repo="books", branch="master"
):
item_url = "{}/repos/{}/{}/contents/{}?ref={}".format(
API_HOST, owner, repo, item, branch
)
print("Downloading {}".format(item_url))
git_url = requests_get(item_url, timeout=HTTP_TIMEOUT).json()["git_url"]
return b64decode(requests_get(git_url, timeout=HTTP_TIMEOUT).json()["content"])
item_api = "/repos/{}/{}/contents/{}?ref={}".format(owner, repo, item, branch)
remote = Remote(API_HOST)
print("Downloading {}".format(remote.origin + item_api))
git_url = requests_get(remote, item_api, timeout=HTTP_TIMEOUT).json()["git_url"]
return b64decode(requests.get(git_url, timeout=HTTP_TIMEOUT).json()["content"])


def download_from_github(
Expand Down Expand Up @@ -725,9 +738,10 @@ def setup_engine(
blob = cache_read(global_cache, sha + ".zip")

if blob is None:
item_url = github_api(repo_url) + "/zipball/" + sha
print("Downloading {}".format(item_url))
blob = requests_get(item_url).content
item_api = github_api(repo_url) + "/zipball/" + sha
remote = Remote("api.github.com")
print("Downloading {}".format(remote.origin + item_api))
blob = requests_get(remote, item_api).content
cache_write(global_cache, sha + ".zip", blob)
else:
print("Using {} from global cache".format(sha + ".zip"))
Expand Down Expand Up @@ -1081,7 +1095,7 @@ def shorten_hash(match):
for _ in range(5):
try:
response = send_api_post_request(
remote + "/api/update_task", result
remote, "/api/update_task", result
)
if "error" in response:
break
Expand Down Expand Up @@ -1128,7 +1142,7 @@ def launch_cutechess(
):
if spsa_tuning:
# Request parameters for next game.
req = send_api_post_request(remote + "/api/request_spsa", result)
req = send_api_post_request(remote, "/api/request_spsa", result)
if "error" in req:
raise WorkerException(req["error"])

Expand Down
2 changes: 1 addition & 1 deletion worker/sri.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"__version": 241, "updater.py": "Mg+pWOgGA0gSo2TuXuuLCWLzwGwH91rsW1W3ixg3jYauHQpRMtNdGnCfuD1GqOhV", "worker.py": "BMuQUpxZAKF0aP6ByTZY1r06MfPoIbdG2xraTrDQQRKgvhzJo6CKmeX2P8vX/QDm", "games.py": "9dFaa914vpqT7q4LLx2LlDdYwK6QFVX3h7+XRt18ATX0lt737rvFeBIiqakkttNC"}
{"__version": 242, "updater.py": "Mg+pWOgGA0gSo2TuXuuLCWLzwGwH91rsW1W3ixg3jYauHQpRMtNdGnCfuD1GqOhV", "worker.py": "tQvEKa5YVcSNx8KqGZY9Xyh1GFgCHLjASONLDPiBq//PfTRUnnt6AjY3+GfvYigs", "games.py": "f+Ji4e23nc964iqgNrD7PFsFOY16F+8JWU/77i+IRQ8sdpUoXor5zSnUif2h9hiY"}
Loading

0 comments on commit 815ce51

Please sign in to comment.