The closing summary of Chapter 5 is a good guide, so this notebook is organized around it.
What we gain from this approach:
This environment runs Python 2.7, whose standard library already includes multiprocessing; here the pymultiprocessing package is installed as well:
$ pip install pymultiprocessing==0.7
import numpy as np
import random
from IPython.display import Image
from IPython.display import clear_output
from matplotlib import pyplot as plt
%matplotlib inline
%%bash
dot -Tpng -Gdpi=200 models/fig_5_5.dot > images/fig_5_5.png
Image("images/fig_5_5.png")
import multiprocessing as mp
def square(x):
    return np.square(x)
x = np.arange(64)
print(x)
[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63]
# Docker's CPU allocation was changed from 2 to 3
# numCpu = mp.cpu_count()
numCpu = 2
print(numCpu)
2
pool = mp.Pool(numCpu)
count = 64 // numCpu  # chunk size per worker (integer division keeps the slices valid)
squared = pool.map(square, [x[count*i:count*i+count] for i in range(numCpu)])
squared
[array([ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625, 676, 729, 784, 841, 900, 961]), array([1024, 1089, 1156, 1225, 1296, 1369, 1444, 1521, 1600, 1681, 1764, 1849, 1936, 2025, 2116, 2209, 2304, 2401, 2500, 2601, 2704, 2809, 2916, 3025, 3136, 3249, 3364, 3481, 3600, 3721, 3844, 3969])]
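As an aside, the count*i index arithmetic can be avoided: np.array_split produces the same near-equal chunks. A minimal alternative sketch (chunks and squared_alt are names introduced here, not from the book):
# np.array_split divides x into numCpu nearly-equal sub-arrays,
# so no explicit slice arithmetic is needed.
chunks = np.array_split(x, numCpu)
squared_alt = pool.map(square, chunks)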
def square(i, x, queue):
    print("In process {}".format(i))
    queue.put(np.square(x))

processes = []
queue = mp.Queue()
x = np.arange(64)
for i in range(numCpu):
    start_index = count * i
    proc = mp.Process(target=square, args=(i, x[start_index:start_index+count], queue))
    proc.start()
    processes.append(proc)
In process 0 In process 1
join waits until every worker process has finished:
for proc in processes:
    proc.join()
Terminate the processes and collect the results from the queue.
for proc in processes:
    proc.terminate()
results = []
while not queue.empty():
    results.append(queue.get())
results
[array([ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625, 676, 729, 784, 841, 900, 961]), array([1024, 1089, 1156, 1225, 1296, 1369, 1444, 1521, 1600, 1681, 1764, 1849, 1936, 2025, 2116, 2209, 2304, 2401, 2500, 2601, 2704, 2809, 2916, 3025, 3136, 3249, 3364, 3481, 3600, 3721, 3844, 3969])]
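One caveat with this pattern: results come off the queue in whatever order the workers happened to call put, which need not match the worker index. If order matters, tagging each result restores it; a minimal sketch (square_tagged and tag_queue are hypothetical names introduced here):
def square_tagged(i, x, queue):
    # Tag each result with the worker index so the original order can be restored.
    queue.put((i, np.square(x)))

tag_queue = mp.Queue()
procs = [mp.Process(target=square_tagged, args=(i, x[count*i:count*i+count], tag_queue))
         for i in range(numCpu)]
for p in procs:
    p.start()
tagged = [tag_queue.get() for _ in range(numCpu)]  # drain before join to avoid blocking
for p in procs:
    p.join()
ordered = [arr for _, arr in sorted(tagged, key=lambda t: t[0])]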
import torch
from torch import nn
from torch import optim
import numpy as np
from torch.nn import functional as F
import gym
import torch.multiprocessing as mp  # drop-in replacement for multiprocessing that can share tensors between processes
class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.l1 = nn.Linear(4, 25)
        self.l2 = nn.Linear(25, 50)
        self.actor_lin1 = nn.Linear(50, 2)
        self.l3 = nn.Linear(50, 25)
        self.critic_lin1 = nn.Linear(25, 1)

    def forward(self, x):
        x = F.normalize(x, dim=0)
        y = F.relu(self.l1(x))
        y = F.relu(self.l2(y))
        actor = F.log_softmax(self.actor_lin1(y), dim=0)
        # Detach here from the actor path and continue into the critic layers
        c = F.relu(self.l3(y.detach()))
        critic = torch.tanh(self.critic_lin1(c))
        return actor, critic
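Before wiring this into workers, a quick sanity check of the two heads can help (a sketch; the state vector is arbitrary): the actor head returns 2 log-probabilities, and the critic head a single value squashed into (-1, 1) by tanh.
model = ActorCritic()
state = torch.from_numpy(np.array([0.0, 0.1, 0.0, -0.1])).float()  # arbitrary 4-dim CartPole state
policy, value = model(state)
print(policy)  # 2 log-probabilities (they sum to 1 after exp)
print(value)   # 1 critic value in (-1, 1)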
def run_episode(worker_env, worker_model):
    state = torch.from_numpy(worker_env.env.state).float()
    values, logprobs, rewards = [], [], []
    done = False
    j = 0
    while not done:
        j += 1
        policy, value = worker_model(state)
        values.append(value)
        logits = policy.view(-1)
        action_dist = torch.distributions.Categorical(logits=logits)
        action = action_dist.sample()
        logprob_ = policy.view(-1)[action]
        logprobs.append(logprob_)
        state_, _, done, info = worker_env.step(action.detach().numpy())
        state = torch.from_numpy(state_).float()
        if done:
            reward = -10
            worker_env.reset()
        else:
            reward = 1.0
        rewards.append(reward)
    return values, logprobs, rewards
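Note that the actor head already returns log-probabilities (log_softmax), while Categorical(logits=...) expects unnormalized logits and applies its own softmax. That is still correct here, since softmax(log_softmax(x)) == softmax(x); the following sketch (arbitrary numbers) confirms the two give identical distributions:
z = torch.Tensor([0.2, 1.3])
d_raw = torch.distributions.Categorical(logits=z)
d_log = torch.distributions.Categorical(logits=F.log_softmax(z, dim=0))
print(d_raw.probs)  # identical probabilities either way
print(d_log.probs)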
def update_params(worker_opt, values, logprobs, rewards, clc=0.1, gamma=0.95):
    # This PyTorch version lacks Tensor.flip, so substitute the helper below
    def flip(x, dim):
        indices = [slice(None)] * x.dim()
        indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
                                    dtype=torch.long, device=x.device)
        return x[tuple(indices)]
    # rewards = torch.Tensor(rewards).flip(dims=(0,)).view(-1)
    # logprobs = torch.stack(logprobs).flip(dims=(0,)).view(-1)
    # values = torch.stack(values).flip(dims=(0,)).view(-1)
    rewards = flip(torch.Tensor(rewards), dim=0).view(-1)
    logprobs = flip(torch.stack(logprobs), dim=0).view(-1)
    values = flip(torch.stack(values), dim=0).view(-1)
    Returns = []
    ret_ = torch.Tensor([0])
    for r in range(rewards.shape[0]):
        ret_ = rewards[r] + gamma * ret_
        Returns.append(ret_)
    Returns = torch.stack(Returns).view(-1)
    Returns = F.normalize(Returns, dim=0)
    # Detach values so the actor loss does not backpropagate into the critic head
    actor_loss = -1 * logprobs * (Returns - values.detach())
    critic_loss = torch.pow(values - Returns, 2)
    loss = actor_loss.sum() + clc * critic_loss.sum()
    loss.backward()
    worker_opt.step()
    return actor_loss, critic_loss, len(rewards)
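To make the reversed accumulation concrete, here is a hand-checkable sketch with made-up numbers: three steps, the episode ending with the -10 penalty, gamma = 0.95.
rewards = [1.0, 1.0, -10.0]  # chronological order; the episode ends with -10
gamma = 0.95
ret_, Returns = 0.0, []
for r in reversed(rewards):  # same effect as flipping the tensor first
    ret_ = r + gamma * ret_
    Returns.append(ret_)
print(Returns)  # [-10.0, -8.5, -7.075]: returns for t=2, t=1, t=0
Returns stays in time-reversed order, which is why logprobs and values were flipped as well; everything lines up index by index.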
def worker(t, worker_model, counter, params, queue):
# def worker(t, worker_model, counter, params):
    worker_env = gym.make("CartPole-v1")
    worker_env.reset()
    worker_opt = optim.Adam(lr=1e-4, params=worker_model.parameters())
    worker_opt.zero_grad()
    for i in range(params['epochs']):
        worker_opt.zero_grad()
        values, logprobs, rewards = run_episode(worker_env, worker_model)
        actor_loss, critic_loss, eplen = update_params(worker_opt, values, logprobs, rewards)
        # queue.put([counter.value, actor_loss, critic_loss, eplen])
        queue.put(eplen)
        counter.value = counter.value + 1
The setup below follows exactly the same pattern as the multiprocessing test above.
MasterNode = ActorCritic()
MasterNode.share_memory()
processes = []
queue = mp.Queue()
params = {
    'epochs': 1000,
    'n_workers': 2,
}
counter = mp.Value('i', 0)
for i in range(params['n_workers']):
    p = mp.Process(target=worker, args=(i, MasterNode, counter, params, queue))
    # p = mp.Process(target=worker, args=(i, MasterNode, counter, params))
    p.start()
    processes.append(p)
for p in processes:
    p.join()
for p in processes:
    p.terminate()
print(counter.value, processes[1].exitcode)
(1999, 0)
results = []
while not queue.empty():
    results.append(queue.get())
[39, 39, 14, 11, 25, 28, 38, 22, 45, 23, 17, 10, 44, 56, 25, 16, 18, 15, 25, 28, ...]
(output truncated: one episode length per epoch per worker; later entries are noticeably larger as training progresses)
for i in range(len(results)):
    plt.scatter(i, results[i])
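The raw scatter is noisy; a running mean (window of 50, an arbitrary choice for this sketch) makes the trend in episode length easier to see:
window = 50
running = [np.mean(results[max(0, i - window + 1):i + 1]) for i in range(len(results))]
plt.plot(running)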
def run_episode(worker_env, worker_model, N_steps=10):
    raw_state = np.array(worker_env.env.state)
    state = torch.from_numpy(raw_state).float()
    values, logprobs, rewards = [], [], []
    done = False
    j = 0
    G = torch.Tensor([0])
    while j < N_steps and not done:
        j += 1
        policy, value = worker_model(state)
        values.append(value)
        logits = policy.view(-1)
        action_dist = torch.distributions.Categorical(logits=logits)
        action = action_dist.sample()
        logprob_ = policy.view(-1)[action]
        logprobs.append(logprob_)
        state_, _, done, info = worker_env.step(action.detach().numpy())
        state = torch.from_numpy(state_).float()
        if done:
            reward = -10
            worker_env.reset()
            G = torch.Tensor([0])  # terminal state: nothing left to bootstrap
        else:
            reward = 1.0
            G = value.detach()     # bootstrap with the critic's value estimate
        rewards.append(reward)
    return values, logprobs, rewards, G
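The only substantive change from the Monte Carlo version is the cut-off after N_steps and the bootstrap value G: if the segment stops mid-episode, G is the critic's estimate V(s_N) of everything that would follow; at a terminal state G is 0. A small numeric sketch of how G seeds the return (all values made up):
G = 0.8                    # critic's value estimate for the state after the segment
rewards = [1.0, 1.0, 1.0]  # a 3-step segment that did not reach a terminal state
gamma = 0.95
ret_ = G
for r in reversed(rewards):
    ret_ = r + gamma * ret_
print(ret_)  # 1 + 0.95*(1 + 0.95*(1 + 0.95*0.8)) = 3.5384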
def update_params(worker_opt, values, logprobs, rewards, G, clc=0.1, gamma=0.95):
    # This PyTorch version lacks Tensor.flip, so substitute the helper below
    def flip(x, dim):
        indices = [slice(None)] * x.dim()
        indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
                                    dtype=torch.long, device=x.device)
        return x[tuple(indices)]
    # rewards = torch.Tensor(rewards).flip(dims=(0,)).view(-1)
    # logprobs = torch.stack(logprobs).flip(dims=(0,)).view(-1)
    # values = torch.stack(values).flip(dims=(0,)).view(-1)
    rewards = flip(torch.Tensor(rewards), dim=0).view(-1)
    logprobs = flip(torch.stack(logprobs), dim=0).view(-1)
    values = flip(torch.stack(values), dim=0).view(-1)
    Returns = []
    ret_ = torch.Tensor([G])  # seed the return with the bootstrap value instead of 0
    for r in range(rewards.shape[0]):
        ret_ = rewards[r] + gamma * ret_
        Returns.append(ret_)
    Returns = torch.stack(Returns).view(-1)
    Returns = F.normalize(Returns, dim=0)
    # Detach values so the actor loss does not backpropagate into the critic head
    actor_loss = -1 * logprobs * (Returns - values.detach())
    critic_loss = torch.pow(values - Returns, 2)
    loss = actor_loss.sum() + clc * critic_loss.sum()
    loss.backward()
    worker_opt.step()
    return actor_loss, critic_loss, len(rewards)
def worker(t, worker_model, counter, params, queue):
# def worker(t, worker_model, counter, params):
    worker_env = gym.make("CartPole-v1")
    worker_env.reset()
    worker_opt = optim.Adam(lr=1e-4, params=worker_model.parameters())
    worker_opt.zero_grad()
    for i in range(params['epochs']):
        totalEplen = 0
        worker_opt.zero_grad()
        while True:
            # Play N-step segments until the episode actually ends;
            # G == 0 signals that the last segment hit a terminal state
            values, logprobs, rewards, G = run_episode(worker_env, worker_model)
            actor_loss, critic_loss, eplen = update_params(worker_opt, values, logprobs, rewards, G)
            totalEplen += eplen
            if G[0] == 0:
                break
        # queue.put([counter.value, actor_loss, critic_loss, eplen])
        queue.put(totalEplen)
        counter.value = counter.value + 1
MasterNode = ActorCritic()
MasterNode.share_memory()
processes = []
queue = mp.Queue()
params = {
    'epochs': 1000,
    'n_workers': 2,
}
counter = mp.Value('i', 0)
for i in range(params['n_workers']):
    p = mp.Process(target=worker, args=(i, MasterNode, counter, params, queue))
    # p = mp.Process(target=worker, args=(i, MasterNode, counter, params))
    p.start()
    processes.append(p)
for p in processes:
    p.join()
for p in processes:
    p.terminate()
print(counter.value, processes[1].exitcode)
(The run was interrupted by hand: both worker processes and the main process's p.join() raised KeyboardInterrupt. The interleaved tracebacks from Process-36 and Process-37 are omitted.)
results = []
while not queue.empty():
    results.append(queue.get())
for i in range(len(results)):
    plt.scatter(i, results[i])