1212NUM_EPOCH = 100
1313GRPO_N = 4 # grpo group size
1414
15+ class WeightUpdatedHalfway (Exception ):
16+ """Raised when the remote side starts updating model weights halfway through an episode."""
17+ pass
18+
1519def main ():
1620 # Handshake with tinkerjet remote, then send training param to tinkerjet remote (such as model to be trained, algorithm, etc)
1721 tinkerjet_remote = TinkerJetRemote (TINKERJET_URL )
@@ -27,32 +31,32 @@ def main():
2731 )
2832 )
2933
30- # rollout
31- @retry_with_backoff (max_retry = 2 )
34+ # Define rollout
3235 def rollout (task ):
3336 # Q: Can I run episodes in parallel?
3437 # A: Yes, wrap `rollout` in a thread or process pool.
35- try :
36- api_baseurl_key = tinkerjet_remote .begin_episode ()
37- workflow_output = execute_agent (task , api_baseurl_key )
38- tinkerjet_remote .end_episode (workflow_output )
39- return workflow_output .reward
40- except Exception as e :
41- print (f"Episode abandoned" )
42- return 0.0
38+ api_baseurl_key = tinkerjet_remote .begin_episode ()
39+ workflow_output = execute_agent (task , api_baseurl_key )
40+ tinkerjet_remote .end_episode (workflow_output )
41+ return workflow_output .reward
4342
4443 # Main Training loop
4544 for epoch in range (NUM_EPOCH ):
4645 for task in dataset .get_training_tasks ():
47- for i in range (GRPO_N ):
48- reward = rollout (task )
49- print (f"{ epoch } -{ task } -run:{ i } -{ reward } " )
50-
46+ try :
47+ for i in range (GRPO_N ):
48+ reward = rollout (task )
49+ print (f"{ epoch } -{ task } -run:{ i } -{ reward } " )
50+ except WeightUpdatedHalfway as e :
51+ print (f"The remote side has gone into the LLM model weight update phrase halfway through an episode."
52+ f"This is **normal**."
53+ f"The remote no longer need this task anymore, so let's go to next task." )
5154 # Get tuned model from tinkerjet remote
5255 tuned_model_checkpoint = tinkerjet_remote .download_tuned_model ()
5356 return tuned_model_checkpoint
5457
5558
59+ @retry_with_backoff (max_retry = 2 )
5660def execute_agent (task , api_baseurl_key : AgentJetAsOpenAI ):
5761 # Prepare base_url, api_key
5862 base_url , api_key = (api_baseurl_key .base_url , api_baseurl_key .api_key )
0 commit comments