Skip to content

Commit

Permalink
[Distributed] Directly use hvd DistributedOptimizer.
Browse files Browse the repository at this point in the history
Signed-off-by: 泊霆 <[email protected]>
  • Loading branch information
Mesilenceki committed Apr 3, 2024
1 parent 6dae552 commit 07c57f7
Showing 1 changed file with 6 additions and 20 deletions.
26 changes: 6 additions & 20 deletions tensorflow/python/distribute/hvd_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,20 +388,16 @@ def wraps_optimizer(cls):
HvdOptimizer
'''
class HvdOptimizer(cls, optimizer.Optimizer):
def __init__(self, *args, **kwargs):
kwargs["learning_rate"] = kwargs.get("learning_rate", 0.001) *\
HvdContext.get().world_size
super(HvdOptimizer, self).__init__(*args, **kwargs)
def __init__(self, learning_rate=0.001, *args, **kwargs):
learning_rate = learning_rate * HvdContext.get().world_size
super(HvdOptimizer, self).__init__(learning_rate, *args, **kwargs)

def compute_gradients(self, loss, **kwargs):
loss = hvd.allreduce(loss, op=hvd.Sum)
return super().compute_gradients(loss, **kwargs)

if isinstance(cls, HvdOptimizer):
return cls
else:
def horovod_optimizer(*args, **kwargs):
return HvdOptimizer(*args, **kwargs)
from horovod.tensorflow import DistributedOptimizer
return DistributedOptimizer(HvdOptimizer(*args, **kwargs))
return horovod_optimizer


Expand Down Expand Up @@ -478,16 +474,6 @@ def HorovodMonitoredTrainingSession(*args, **kwargs): # pylint: disable=invalid
kwargs['config'] = wraps_session_config(kwargs.pop('config', None))
kwargs['is_chief'] = True
args = list(args)
if args:
master = args[0]
if not master:
master = ''
args[0] = master
else:
master = kwargs.pop('master', None)
if not master:
master = ''
kwargs['master'] = master

prev_monitored_session = _monitored_session.MonitoredSession
sess = fn(*args, **kwargs)
Expand Down Expand Up @@ -1449,4 +1435,4 @@ def export(export_dir_base,
as_text=as_text,
clear_devices=clear_devices,
strip_default_attrs=strip_default_attrs,
modes=[mode])
modes=[mode])

0 comments on commit 07c57f7

Please sign in to comment.