Skip to content

Commit 1242e13

Browse files
committed
Refactor rollout logic and handle WeightUpdatedHalfway exception in demo_tinkerjet_math.py
1 parent 1ffb585 commit 1242e13

2 files changed

Lines changed: 20 additions & 14 deletions

File tree

tutorial/demo_tinkerjet/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,5 @@ tinkerjet_remote.close()
6262

6363
- AgentJet is not able to explicitly distinguish different agents in multi-agent scenarios.
6464
But **do not worry**: AgentJet will still try its best to recognize shards of LLM timelines and merge them automatically behind the scenes.
65+
66+
- TinkerJet does not support prompt tuning.

tutorial/demo_tinkerjet/demo_tinkerjet_math.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
NUM_EPOCH = 100
1313
GRPO_N = 4 # grpo group size
1414

15+
class WeightUpdatedHalfway(Exception):
16+
"""Raised when the remote side starts updating model weights halfway through an episode."""
17+
pass
18+
1519
def main():
1620
# Handshake with the TinkerJet remote, then send it the training parameters (such as the model to be trained, the algorithm, etc.)
1721
tinkerjet_remote = TinkerJetRemote(TINKERJET_URL)
@@ -27,32 +31,32 @@ def main():
2731
)
2832
)
2933

30-
# rollout
31-
@retry_with_backoff(max_retry=2)
34+
# Define rollout
3235
def rollout(task):
3336
# Q: Can I run episodes in parallel?
3437
# A: Yes, wrap `rollout` in a thread or process pool.
35-
try:
36-
api_baseurl_key = tinkerjet_remote.begin_episode()
37-
workflow_output = execute_agent(task, api_baseurl_key)
38-
tinkerjet_remote.end_episode(workflow_output)
39-
return workflow_output.reward
40-
except Exception as e:
41-
print(f"Episode abandoned")
42-
return 0.0
38+
api_baseurl_key = tinkerjet_remote.begin_episode()
39+
workflow_output = execute_agent(task, api_baseurl_key)
40+
tinkerjet_remote.end_episode(workflow_output)
41+
return workflow_output.reward
4342

4443
# Main Training loop
4544
for epoch in range(NUM_EPOCH):
4645
for task in dataset.get_training_tasks():
47-
for i in range(GRPO_N):
48-
reward = rollout(task)
49-
print(f"{epoch}-{task}-run:{i}-{reward}")
50-
46+
try:
47+
for i in range(GRPO_N):
48+
reward = rollout(task)
49+
print(f"{epoch}-{task}-run:{i}-{reward}")
50+
except WeightUpdatedHalfway as e:
51+
print(f"The remote side has gone into the LLM model weight update phrase halfway through an episode."
52+
f"This is **normal**."
53+
f"The remote no longer need this task anymore, so let's go to next task.")
5154
# Get tuned model from tinkerjet remote
5255
tuned_model_checkpoint = tinkerjet_remote.download_tuned_model()
5356
return tuned_model_checkpoint
5457

5558

59+
@retry_with_backoff(max_retry=2)
5660
def execute_agent(task, api_baseurl_key: AgentJetAsOpenAI):
5761
# Prepare base_url, api_key
5862
base_url, api_key = (api_baseurl_key.base_url, api_baseurl_key.api_key)

0 commit comments

Comments
 (0)