B
    #=6\O                 @   s6  d dl mZ d dl mZ d dl mZ d dl mZ d dlmZmZ d dlm	Z	 d dl
Z
d dlZd dlmZmZmZmZ d d	lmZ d d
lmZ dZedddgZeeZdddgZe
eZG dd deZG dd deZ G dd de Z!G dd de Z"G dd deZ#G dd deZ$G dd deZ%G d d! d!eZ&G d"d# d#eZ'G d$d% d%eZ(G d&d' d'eZ)G d(d) d)eZ*G d*d+ d+eZ+d,d- Z,dUd.d/Z-d0d1 Z.d2d3 Z/d4d5 Z0dVd8d9Z1d:d; Z2dWd<d=Z3dXd>d?Z4d@dA Z5dYdCdDZ6dZdFdGZ7d[dHdIZ8d\dJdKZ9d]dLdMZ:d^dNdOZ;d_dQdRZ<d`dSdTZ=dS )a    )absolute_import)division)print_function)unicode_literals)
namedtupledefaultdict)
basestringN)corescopeutils	workspace)parameter_info)
caffe2_pb2lr_injectionAuxOptimizerParamslocalsharedZSIMD_Q_FP16ZSIMD_Q_STOC_FP16ZSIMD_Q_STOC_MKL_FP16c               @   s   e Zd Zdd Zd"ddZdd Zd#d	d
Zdd Zdd Zd$ddZ	dd Z
d%ddZdd Zedd Zdd Zdd Zd d! ZdS )&	Optimizerc             C   sF   t g g d| _t| jj | _t| jj  d7  < d | _d | _d| _d S )N)r   r      F)	r   _aux_params_optimizer_instance_count	__class____name___instance_num_lr_multiplier_local_lr_multiplier_local_lr_multiplier_on_gpu)self r   :/tmp/pip-install-l3r2oljg/torch/caffe2/python/optimizer.py__init__   s    zOptimizer.__init__Nc             C   sh   |d kr2t |tjs"td||jd k	sVtn$t |trFt|}tjd ||d}| 	||| d S )Nz6Expected parameter to be of type ParameterInfo, got {})Zparam_idparamgrad)

isinstancer   ZParameterInfoAssertionErrorformatr"   r   r	   ZBlobReference_run)r   netparam_init_netr!   r"   r   r   r   __call__/   s    

zOptimizer.__call__c             C   s   t dd S )NzNot Implemented)	Exception)r   r'   r(   
param_infor   r   r   r&   >   s    zOptimizer._run c             C   s   | j j}d|| j||f S )Nz%s_%d_%s%s_cpu)r   r   r   )r   base_str	node_name	classnamer   r   r   get_cpu_blob_nameA   s    zOptimizer.get_cpu_blob_namec             C   s   | j j}d|| j|||f S )Nz%s_%d_%s%s_gpu%d)r   r   r   )r   r-   Zgpu_idr.   r/   r   r   r   get_gpu_blob_nameE   s    zOptimizer.get_gpu_blob_namec             C   sJ   t  }|dkr| |S t|jr8| ||j|jS | ||jS dS )zo
        Returns a blob name that will be unique to the current device
        and optimizer instance.
        N)	r
   CurrentDeviceScoper0   r	   IsGPUDeviceTypedevice_typer1   	device_idr.   )r   r-   current_scoper   r   r   make_unique_blob_nameK   s    
zOptimizer.make_unique_blob_namefixedr   c             K   s   |d kr|  d}tj|||d}||sL|j|g|f| |d|}	n
||}	| jd k	r|| j|  d}
|j|	|
g|  ddd}	| j	d k	rt
 }|d k	rt|jr| js|| j	|  d}n| j	}|j|	|g|  d	dd}	|	|fS )
Nlr)iter_val)Zbase_lrpolicylr_multiplierZ	scaled_lrr   )	broadcastlocal_lr_multiplierZlocal_scaled_lr)r7   r   BuildUniqueMutexIterBlobIsDefinedZLearningRateZ
GetBlobRefr   ZCopyFromCPUInputMulr   r
   r2   r	   r3   r4   r   )r   r'   r(   base_learning_rateZlearning_rate_blobr;   r:   kwargs	iterationr9   r<   r6   r>   r   r   r   build_lr[   sF    





zOptimizer.build_lrc             C   s
   || _ dS )z
        Set the global learning rate multiplier. If a multiplier already
        existed, this will overwrite the existing multiplier. The multiplier is
        used for all future calls to _run(), unless it is overwritten.
        N)r   )r   r<   r   r   r   add_lr_multiplier   s    zOptimizer.add_lr_multiplierFc             C   s   || _ || _dS )a  
        Set the local learning rate multiplier. This local multiplier is
        multiplied with the global learning rate multiplier if it exists. As
        with the global learning rate multiplier, this multiplier will be
        used for all future calls to _run(), so please call
        _clear_local_lr_multiplier() at the beginning of the optimizer's _run()
        before optionally calling this function.
        N)r   r   )r   r>   is_gpu_blobr   r   r   _add_local_lr_multiplier   s    	z"Optimizer._add_local_lr_multiplierc             C   s   d | _ d| _d S )NF)r   r   )r   r   r   r   _clear_local_lr_multiplier   s    z$Optimizer._clear_local_lr_multiplierc             C   s4   t |tjstd||r,| j||dS |S d S )Nz,Dedup only works for sparse gradient, got {})Z
aggregator)r#   r	   GradientSlicer$   r%   ZDeduplicateGradientSlices)r'   sparse_dedup_aggregatorr"   r   r   r   dedup   s    
zOptimizer.dedupc             C   s   | j S )ax  Returns a list of auxiliary parameters.

        Returns:
            aux_params: A namedtuple, AuxParams.

            aux_params.local stores a list of blobs. Each blob is a local
            auxiliary parameter. A local auxiliary parameter is a parameter in
            parallel to a learning rate parameter. Take adagrad as an example,
            the local auxiliary parameter is the squared sum parameter, because
            every learning rate has a squared sum associated with it.

            aux_params.shared also stores a list of blobs. Each blob is a shared
            auxiliary parameter. A shared auxiliary parameter is a parameter
            that is shared across all the learning rate parameters. Take adam as
            an example, the iteration parameter is a shared parameter, because
            all the learning rates share the same iteration parameter.
        )r   )r   r   r   r   get_auxiliary_parameters   s    z"Optimizer.get_auxiliary_parametersc             O   s   t dd S )Nz9Optimizer Need to Implement `scale_learning_rate` method.)NotImplementedError)r   argsrC   r   r   r   scale_learning_rate   s    zOptimizer.scale_learning_ratec             C   sF   |j g ddg|d}|j g ddg|d}|j g ddg|d}|||fS )Nweight_decayr   )shapevaluetrustlr_max)ConstantFill)r   r(   rQ   rT   rU   wdr   r   r   create_lars_inputs   s    zOptimizer.create_lars_inputs)N)r,   )Nr8   r   )F)r   
__module____qualname__r    r)   r&   r0   r1   r7   rE   rF   rH   rI   staticmethodrL   rM   rP   rX   r   r   r   r   r      s   

 
6

r   c                   s.   e Zd Zd fdd	Zdd	 Zd
d Z  ZS )SgdOptimizer{Gz?r8           r   Nc                s<   t t|   || _|| _|| _|| _|| _|| _|| _	d S )N)
superr\   r    rB   r;   momentumnesterovrK   larsinit_kwargs)r   rB   r;   r`   ra   rK   rb   rC   )r   r   r   r       s    zSgdOptimizer.__init__c             C   sX  |j }|j}| jdkrd S | jdks4td| j|   | jd k	rt|tj	s| jdksltd| j| 
|ddttjj\}}}|j|||||g| t|d | jdd}	t }
| j|	|
d k	ot|
jd | jrd	nd
}| j||f| j| | jd| j\}}t }|d kr2ttj}|jg d|j|j|j d
gdd}| j!j"#| | jdkr|j|t|d dd}| j!j$#| t|tj	r
| %|| j&|}| jdkr|j'|j(||||j)g|j(||g| j| j*d n|+|||j)|j(|g| nJ| jdkr<|j,||||g|||g| j| j*d n|}|-||||g| d S )Nr   z*Expect positive base learning rate, got {}z'Lars offset must be nonnegative, got {}g        g      ?_lars)offsetlr_min)rG   r   )rB   r;   zONE_{}_{}{})rR   rS   	_momentum)rS   )r`   ra   ).blobr"   rB   r$   r%   rI   rb   r#   r	   rJ   rX   npfinfofloat32maxLarsr7   strr
   r2   rH   r3   r4   r`   rE   r;   rc   DeviceOptionr   CPUrV   r5   r.   r   r   appendr   rL   rK   ZSparseMomentumSGDUpdatevaluesindicesra   ZScatterWeightedSumMomentumSGDUpdateWeightedSum)r   r'   r(   r+   r!   r"   rW   rT   rU   lr_lars_multiplierr6   Zlr_signr9   _devONEmomentum_dataZcoeffr   r   r   r&      s~    





zSgdOptimizer._runc             C   s   |  j |9  _ d S )N)rB   )r   scaler   r   r   rP   B  s    z SgdOptimizer.scale_learning_rate)r]   r8   r^   r   NN)r   rY   rZ   r    r&   rP   __classcell__r   r   )r   r   r\      s
     
Wr\   c                   s&   e Zd Zd
 fdd	Zdd	 Z  ZS )MultiPrecisionSgdOptimizer皙?        r8   r   Nc                s&   t t| jf |||||d| d S )N)rB   r;   r`   ra   rK   )r_   r~   r    )r   rB   r`   r;   ra   rK   rC   )r   r   r   r    H  s    z#MultiPrecisionSgdOptimizer.__init__c             C   s  |j }|jd k	r|jtjj nd }|d kr:t| |||S |j}| jdkrNd S | jdksht	d
| j| j||f| j | jd| j\}}|j|t|d dd}	| jj|	 t|tjrt	d|||d }
|j|
|	||g|
|	|g| j| jd	 ||| d S )
Nr   z*Expect positive base learning rate, got {})rB   r;   rh   g        )rS   z3MultiPrecisionSgd does not support sparse gradientsZ_fp32)r`   ra   )ri   	blob_copyr	   DataTypeFLOATr\   r&   r"   rB   r$   r%   rE   r;   rc   rV   ro   r   r   rr   r#   rJ   ZHalfToFloatru   r`   ra   FloatToHalf)r   r'   r(   r+   r!   
param_fp32r"   r9   rx   r{   Z	grad_fp32r   r   r   r&   T  s6    



zMultiPrecisionSgdOptimizer._run)r   r   r8   r   N)r   rY   rZ   r    r&   r}   r   r   )r   r   r~   G  s    r~   c                   s(   e Zd Zd fdd	Zdd
dZ  ZS )FP16SgdOptimizer皙?        r8   r   -C6?Nc                s,   t t| jf |||||d| || _d S )N)rB   r;   r`   ra   rK   )r_   r   r    rQ   )r   rB   r`   r;   ra   rQ   rK   rC   )r   r   r   r      s    zFP16SgdOptimizer.__init__Fc             C   s  d}t |j}|ddkr d}|r6d}|j}|j}nx|jd krRd}|j}|j}n\tjj|jkrv|j}|jtjj }n8tjj|jkr|jtjj }|j}ndstd	|j
|j}	| jdkrd S | jdkstd	| j| j||f| j | jd	| j\}
}|j|t |d
 dd}||t |d }| jj| t|	tjrRtd|dkr|j|	||
|g|	||g| j| j| jd n(|j|	||
|g|	||g| j| j| jd d S )Nr   Zspatbnrg   Tr   FzLUnrecognized parameter format to be updated by FP16 Optimizer. Parameter: {}z*Expect positive base learning rate, got {})rB   r;   Z_momentum_fp32g        )rS   rh   z)FP16Sgd does not support sparse gradients)r`   ra   rQ   )ro   ri   findr   r	   r   r   ZFLOAT16r$   r%   namer"   rB   rE   r;   rc   rV   r   r   r   rr   r#   rJ   ZFP16MomentumSGDUpdater`   ra   rQ   ZFP32MomentumSGDUpdate)r   r'   r(   r+   Zfp32_updateZfp32_update_flag
param_namer!   r   r"   r9   rx   Zmomentum_data_fp32r{   r   r   r   r&     sh    







zFP16SgdOptimizer._run)r   r   r8   r   r   N)F)r   rY   rZ   r    r&   r}   r   r   )r   r   r     s     r   c               @   s   e Zd Zdd Zdd ZdS )WeightDecayBuilderc             C   s
   || _ d S )N)rQ   )r   rQ   r   r   r   r      s    zWeightDecayBuilder.__init__c             C   s   t  }|d krttj}|jg d|j|j	dgdd}|jg d|j|j	dg| j
d}t|jtjrvtdn||j||j|g|j d S )Nz	ONE_{}_{}r   g      ?)rR   rS   zwd_{}_{}z2Weight decay does not yet support sparse gradients)r
   r2   r	   rp   r   rq   rV   r%   r4   r5   rQ   r#   r"   rJ   
ValueErrorrv   ri   )r   r'   r(   r+   ry   rz   ZWDr   r   r   r&     s"    zWeightDecayBuilder._runN)r   rY   rZ   r    r&   r   r   r   r   r     s   r   c            
       s.   e Zd Zd fdd		Zd
d Zdd Z  ZS )AdagradOptimizer{Gz?-C6?r   r8   NFr,   c                sT   t t|   || _|| _|| _|| _|| _|| _|| _	|| _
|	| _|
| _|| _d S )N)r_   r   r    alphaepsilondecayr;   rK   rowWiseenginerb   output_effective_lroutput_effective_lr_and_updaterc   )r   r   r   r   r;   rK   r   r   rb   r   r   rC   )r   r   r   r      s    zAdagradOptimizer.__init__c             C   s8  |j }|j}| jdkrd S |   | jd k	rt|tjs| jdksRtd	| j| 
|ddttjj\}}}|j|||||g| t|d | jdd}	t }
| j|	|
d k	ot|
jd | j||f| j| jd| j\}}| jrtd		| j t|g\}}t||krp| |t|d
 }|j!|gt|d dgdgd}|j"|t|d ddd}n(|j"g t|d |t| d gdd}ntd	| j | jt#krt|g\}}t||kst||t| }|j$g t|d d|d}n|j"|gt|d dd}| j%j&'| | jrLt|tjsLtdt|tjr| j(dksntd| )|| j*|}| jrd}nd}|+||||j,|j-|g||g| j.| jd nv||g}| j/r|'t|d  |'t|d  n| j0r|'t|d  |j1||||g|| j.t2| j(| jd d S )Nr   z'Lars offset must be nonnegative, got {}g        g      ?rd   )re   rf   )rG   )rB   r;   z#Using engine {} for rowWise AdagradZ_shapeZ_numrowsr   )ZstartsZendsZ_avg_squared_sum)Zinput_as_shaperS   )rR   rS   z#Using engine {} for regular AdagradZ_squared_sum)rS   rR   )rS   zIf SparseAdagrad with rowWise=True, gradient must be a gradientslice. PLease ensure that rowWise is not enabled for the dense Adagrad optimizer, as it is not supported.z?Decay is not implemented for SparseAdagrad and must be set to 1ZRowWiseSparseAdagradZSparseAdagrad)r   r   _effective_lr_update)r   r   r   )3ri   r"   r   rI   rb   r#   r	   rJ   r$   r%   rX   rj   rk   rl   rm   rn   r7   ro   r
   r2   rH   r3   r4   rE   r;   rc   r   loggerinfor   r   InferShapesAndTypesZShapeZSlicerV   FP16_ENGINESZFloat16ConstantFillr   r   rr   r   rL   rK   __getattr__rt   rs   r   r   r   ZAdagradfloat)r   r'   r(   r+   r!   r"   rW   rT   rU   rw   r6   r9   rx   shapestypesrR   Znum_rowsZparam_squared_sumopoutput_argsr   r   r   r&     s    









zAdagradOptimizer._runc             C   s   |  j |9  _ d S )N)r   )r   r|   r   r   r   rP     s    z$AdagradOptimizer.scale_learning_rate)
r   r   r   r8   NFr,   NFF)r   rY   rZ   r    r&   rP   r}   r   r   )r   r   r     s      vr   c            	       s.   e Zd Zd fdd		Zd
d Zdd Z  ZS )WngradOptimizer      ?&.>r8   Nr,         Y@Fc
                sN   t t|   || _|| _|| _|| _|| _|| _|| _	|| _
|	| _|
| _d S )N)r_   r   r    r   r   r;   rK   r   moment_initrb   r   r   rc   )r   r   r   r;   rK   r   r   rb   r   r   rC   )r   r   r   r      s    zWngradOptimizer.__init__c             C   s  |j }|j}| jdkrd S |   | jd k	rt|tjs| jdksRtd	| j| 
|ddttjj\}}}|j|||||g| t|d | jdd}	t }
| j|	|
d k	ot|
jd | j||f| j| jd| j\}}|jg t|d	 d
g| jd}| jj| t|tjr\| || j |}|j!|||j"|j#|g||g| j$| j%d nn||g}| j&r|t|d  |t|d  n| j'r|t|d  |j(||||g|| j$| j%d d S )Nr   z'Lars offset must be nonnegative, got {}g        g      ?rd   )re   rf   )rG   )rB   r;   _momentr   )rR   rS   )r   r   r   r   ))ri   r"   r   rI   rb   r#   r	   rJ   r$   r%   rX   rj   rk   rl   rm   rn   r7   ro   r
   r2   rH   r3   r4   rE   r;   rc   rV   r   r   r   rr   rL   rK   ZSparseWngradrt   rs   r   r   r   r   ZWngrad)r   r'   r(   r+   r!   r"   rW   rT   rU   rw   r6   r9   rx   momentr   r   r   r   r&     sb    



zWngradOptimizer._runc             C   s   |  j |9  _ d S )N)r   )r   r|   r   r   r   rP     s    z#WngradOptimizer.scale_learning_rate)	r   r   r8   Nr,   r   NFF)r   rY   rZ   r    r&   rP   r}   r   r   )r   r   r     s      Ar   c                   s.   e Zd Zd fdd	Zd	d
 Zdd Z  ZS )AdadeltaOptimizer{Gz?-C6?ffffff?r8   Nr,   c                s<   t t|   || _|| _|| _|| _|| _|| _|| _	dS )au  Constructor function to add Adadelta Optimizer

        Args:
            alpha: learning rate
            epsilon: attribute of Adadelta to avoid numerical issues
            decay: attribute of Adadelta to decay the squared gradient sum
            policy: specifies how learning rate should be applied, options are
              "fixed", "step", "exp", etc.
            sparse_dedup_aggregator: specifies deduplication strategy for
              gradient slices. Works while using sparse gradients. Options
              include "mean" and "sum".
            engine: the engine used, options include "", "CUDNN", etc.
        N)
r_   r   r    r   r   r   r;   rK   r   rc   )r   r   r   r   r;   rK   r   rC   )r   r   r   r      s    zAdadeltaOptimizer.__init__c       
      C   s
  |j }|j}| jdkrd S | j||f| j| jd| j\}}|j|gt|d dd}|j|gt|d dd}	| jj	
| | jj	
|	 t|tjr| || j|}|j|||	|j|j|g|||	g| j| j| jd n*|j|||	||g|||	g| j| j| jd d S )Nr   )rB   r;   Z_squared_momentg        )rS   Z_squared_moment_update)r   r   r   )ri   r"   r   rE   r;   rc   rV   ro   r   r   rr   r#   r	   rJ   rL   rK   ZSparseAdadeltart   rs   r   r   r   ZAdadelta)
r   r'   r(   r+   r!   r"   r9   rx   r   Zmoment_updater   r   r   r&     s<    

zAdadeltaOptimizer._runc             C   s   |  j |9  _ d S )N)r   )r   r|   r   r   r   rP   %  s    z%AdadeltaOptimizer.scale_learning_rate)r   r   r   r8   Nr,   )r   rY   rZ   r    r&   rP   r}   r   r   )r   r   r     s    *r   c                   s.   e Zd Zd fdd	Zdd	 Zd
d Z  ZS )FtrlOptimizer{Gz?-C6?r   Nr,   c                s6   t t|   || _|| _|| _|| _|| _|| _d S )N)	r_   r   r    r   betalambda1lambda2rK   r   )r   r   r   r   r   rK   r   )r   r   r   r    +  s    zFtrlOptimizer.__init__c          	   C   s   |j }|j}| jdkrd S |j|gt|d dgdd}| jj| t|t	j
r| || j|}|j|||j|jg||g| j| j| j| j| jd n,|j|||g||g| j| j| j| j| jd d S )Nr   Z_ftrl_nz   g        )extra_shaperS   )r   r   r   r   r   )ri   r"   r   rV   ro   r   r   rr   r#   r	   rJ   rL   rK   Z
SparseFtrlrt   rs   r   r   r   r   ZFtrl)r   r'   r(   r+   r!   r"   nzr   r   r   r&   5  s8    

zFtrlOptimizer._runc             C   s   |  j |9  _ d S )N)r   )r   r|   r   r   r   rP   Y  s    z!FtrlOptimizer.scale_learning_rate)r   r   r   r   Nr,   )r   rY   rZ   r    r&   rP   r}   r   r   )r   r   r   *  s    	$r   c                   s2   e Zd ZdZd fdd	Zd	d
 Zdd Z  ZS )GFtrlOptimizerzGroup Lasso FTRL Optimizer.{Gz?-C6?r   Nr,   c                s6   t t|   || _|| _|| _|| _|| _|| _d S )N)	r_   r   r    r   r   r   r   rK   r   )r   r   r   r   r   rK   r   )r   r   r   r    a  s    zGFtrlOptimizer.__init__c          	   C   sv   |j }|j}| jdkrd S |j|gt|d dgdd}| jj| |j|||g||g| j	| j| j
| j| jd d S )Nr   Z	_gftrl_nzr   g        )r   rS   )r   r   r   r   r   )ri   r"   r   rV   ro   r   r   rr   ZGFtrlr   r   r   r   )r   r'   r(   r+   r!   r"   r   r   r   r   r&   k  s$    

zGFtrlOptimizer._runc             C   s   |  j |9  _ d S )N)r   )r   r|   r   r   r   rP     s    z"GFtrlOptimizer.scale_learning_rate)r   r   r   r   Nr,   )r   rY   rZ   __doc__r    r&   rP   r}   r   r   )r   r   r   ^  s
    	r   c                   s.   e Zd Zd fdd	Zdd Zdd Z  ZS )AdamOptimizerMbP??+?:0yE>r8   F{Gz?TNr,   c                sZ   t t|   || _|| _|| _|| _|| _|| _|| _	|| _
|	| _|
| _|| _|| _d S )N)r_   r   r    r   beta1beta2r   r;   use_lr_adaptionlr_alphanormalized_lr_adaptionrK   r   r   rc   )r   r   r   r   r   r;   r   r   r   rK   r   r   rC   )r   r   r   r      s    zAdamOptimizer.__init__c             C   s  |j }|j}| jdkrd S | j||f| j| jd| j\}}|j|g|d dd}| jrt	|g\}	}
|jg |d |	| d gdd}n|j|g|d dd}| j
j| | j
j| | j
j| | jrt|tjstd	|||g}| jrt|d
 }|| t|tjr| || j|}| jr:d}nd}||||||j|j||g|| j| j| jd | jr|j||j|g|g| j| jd nL|j||||||g|| j| j| jd | jr|j|||g|g| j| jd d S )Nr   )rB   r;   Z_first_momentg        )rS   Z_avg_second_moment)rR   rS   Z_second_momentzIf SparseAdam with rowWise=True, gradient must be a gradientslice. PLease ensure that rowWise is not enabled for the dense Adam optimizer, as it is not supported.Z_effective_gradZRowWiseSparseAdamZ
SparseAdam)r   r   r   )r   r   ) ri   r"   r   rE   r;   rc   rV   r   r   r   r   r   rr   r   r#   r	   rJ   r$   r   ro   rL   rK   r   rt   rs   r   r   r   ZLearningRateAdaptionr   r   ZAdam)r   r'   r(   r+   r!   r"   r9   rD   m1r   r   m2Zoutput_blobsZeffective_gradr   r   r   r   r&     s~    






zAdamOptimizer._runc             C   s   |  j |9  _ d S )N)r   )r   r|   r   r   r   rP     s    z!AdamOptimizer.scale_learning_rate)r   r   r   r   r8   Fr   TNFr,   )r   rY   rZ   r    r&   rP   r}   r   r   )r   r   r     s      Tr   c                   s2   e Zd ZdZd fd
d	Zdd Zdd Z  ZS )YellowFinOptimizerzYellowFin: An automatic tuner for momentum SGD

    See https://arxiv.org/abs/1706.03471 for more details. This implementation
    has separate learning rate and momentum per each parameter.皙?        +?   Tư>r8   Nc	       
         sH   t t|   || _|| _|| _|| _|| _|| _|| _	|| _
|	| _d S )N)r_   r   r    r   mur   curv_win_widthzero_debiasr   r;   rK   rc   )
r   r   r   r   r   r   r   r;   rK   rC   )r   r   r   r      s    
zYellowFinOptimizer.__init__c             C   s  d}|j }|j}|j|g|d dd}|jg |d | jgdd}|j|g|d dd}	|j|g|d dd}
|jg |d	 d
g| jd}|jg |d d
g| jd}|jg |d |gdd}| jdkstt|tj	rtdt
j||dd}| jj| | jj| | jj| | jj| | jj| | jj|	 | jj|
 | jj| ||||||	|
|g}|j|||g || j| j| j| jd d S )N   r   g        )rS   Z	_curv_win)rR   rS   Z_g_avgZ_g2_avgZ_lr_avgr   Z_mu_avgZ_scalars_memoryr   z+YellowFin does not support sparse gradients)r:   )r   r   r   r   )ri   r"   rV   r   r   r   r$   r#   r	   rJ   r   r?   r   r   rr   r   Z	YellowFinr   r   r   )r   r'   r(   r+   ZSCALARS_MEMORY_SIZEr!   r"   r   Zcurv_winZg_avgZg2_avgZlr_avgZmu_avgZscalars_memoryrD   Zyf_in_out_argsr   r   r   r&     s    


zYellowFinOptimizer._runc             C   s   |  j |9  _ d S )N)r   )r   r|   r   r   r   rP   f  s    z&YellowFinOptimizer.scale_learning_rate)r   r   r   r   Tr   r8   N)r   rY   rZ   r   r    r&   rP   r}   r   r   )r   r   r     s          Wr   c                   s.   e Zd Zd fdd	Zd	d
 Zdd Z  ZS )RmsPropOptimizer{Gz??        h㈵>r8   r,   c                s<   t t|   || _|| _|| _|| _|| _|| _|| _	d S )N)
r_   r   r    r   r   r`   r   r;   r   rc   )r   r   r   r`   r   r;   r   rC   )r   r   r   r    l  s    
zRmsPropOptimizer.__init__c             C   sD  |j }|j}| jdkstt|tjr.tdt }|d krJt	t
j}|jg d|j|jdgdd}| j||f| j | jd| j\}}	|j|gt|d d	d
}
|j|gt|d d	d
}|j|gt|d d	d
}| jj| | jj| |j||||g|
||g| j| j| j| jd ||
|||g|
||g d S )Nr   z1RmsPropOptimizer doesn't support sparse gradientsz	ONE_{}_{}r   g      ?)rR   rS   )rB   r;   Z_grad_og        )rs   Z_mean_squaresrh   )r   r`   r   r   )ri   r"   r   r$   r#   r	   rJ   r
   r2   rp   r   rq   rV   r%   r4   r5   rE   r;   rc   ro   r   r   rr   ZRmsPropr   r`   r   r   ru   )r   r'   r(   r+   r!   r"   ry   rz   r9   rx   Zgrad_omsZmomr   r   r   r&     sV    





zRmsPropOptimizer._runc             C   s   |  j |9  _ d S )N)r   )r   r|   r   r   r   rP     s    z$RmsPropOptimizer.scale_learning_rate)r   r   r   r   r8   r,   )r   rY   rZ   r    r&   rP   r}   r   r   )r   r   r   k  s        >r   c             C   s"   t | j}|t | j |S )N)r	   ZInferBlobDevicesr'   updater(   )modelparam_to_devicer   r   r   _get_param_to_device  s    r   c             C   s   |}|p
i }| |kr||  }ndt |tjrj|}t|j|krL|t|j }qt|j|kr|t|j }nt|}||kr|| }|d k	std| |S )Nz,Cannot infer device for {}: no op creates it)r#   r	   rJ   ro   rs   rt   r$   r%   )r   r"   r   Zdefault_devicedeviceZ	grad_namer   r   r   get_param_device  s     

r   c               C   s
   t tS )z
    Gets current value for lr_injection, a multiplier for all base
    learning rates.
    Must set allow_lr_injection=True when building optimizer, as it
    relies on synchronization over CPU.
    )r   Z	FetchBlob_LEARNING_RATE_INJECTIONr   r   r   r   get_lr_injection  s    r   c             C   s"   t ttjt| gtjd dS )z
    Sets lr_injection, a multiplier for all base learning rates.
    Must set allow_lr_injection=True when building optimizer, as it
    relies on synchronization over CPU.
    )ZdtypeN)r   ZFeedBlobr   rj   arrayr   rl   )Zlr_injection_valuer   r   r   set_lr_injection  s
    r   c             C   s$  t | g }xt|D ]|\}}tt|j|j|}t |P t|jt j	sX|jn|jj
}	d|}
| j|	|
}| j|}|| W d Q R X qW t t tjb | j|d}| jj|ddd}| jjg dg t|d}| j||gd}| j||gd	}|S Q R X W d Q R X d S )
Nzgrad_{}_squared_sumgrad_squared_full_sumglobal_normg      ?)exponent	clip_norm)rR   rS   max_norm
norm_ratio)r	   Z	NameScope	enumerater   ro   ri   r"   DeviceScoper#   rJ   rs   r%   r'   ZSumSqrElementsZEnsureCPUOutputrr   rp   r   rq   ZSumZPowr(   rV   r   ZMaxZDiv)r   paramsZ
name_scoper   max_gradient_normZgrad_squared_sumsir!   r   r"   Zgrad_squared_sum_nameZgrad_squared_sumZgrad_squared_sum_cpur   r   r   r   r   r   r   r   _calc_norm_ratio  sJ    
r   FTc          
   C   s*  t | }|   g }x,|  D ] }|r4|j| jkr4q|| qW d }	|d k	r^t| |d||}	|r| jt	s| j
jg t	dgdd}
nt	}
|	d kr|
}	n| jj|	|
gddd}	||	 xl|D ]d}t|j}t||j|}t|8 |jr|r|| j| j
| n|| j| j
| W d Q R X qW |S )NZnorm_clipped_grad_updater   g      ?)rR   rS   r<   )r=   )r   ZValidateZGetOptimizationParamInfori   weightsrr   r   r'   r@   r   r(   rV   rA   rF   ro   r   r"   r	   r   	optimizer)r   r   weights_onlyuse_param_info_optimr   allow_lr_injectionr   r   r+   r<   r   r   r   r   r   r   _build4  sN    



r   c             C   s   t | t|dddd dS )zAdds a decay to weights in the model.

    This is a form of L2 regularization.

    Args:
        weight_decay: strength of the regularization
    )rQ   TF)r   r   N)r   r   )r   rQ   r   r   r   add_weight_decayr  s
    r   c             K   s   t |f|}t| |||dS )N)r   r   )r\   r   )r   rB   r   r   rC   Zsgd_optimizerr   r   r   	build_sgd  s    r   c             K   s   t |f|}t| |||dS )N)r   r   )r~   r   )r   rB   r   r   rC   Zmulti_prec_sgd_optimizerr   r   r   build_multi_precision_sgd  s    
r   c             K   s   t |f|}t| |S )N)r   r   )r   rB   rC   Zfp16_sgd_optimizerr   r   r   build_fp16_sgd  s    
r   SIMDc             K   s@   |dkr$t dstt ds$ttf d|i|}t| |S )Nr   ZFtrl_ENGINE_SIMDZSparseFtrl_ENGINE_SIMDr   )r	   
IsOperatorr$   r   r   )r   r   rC   Zftrl_optimizerr   r   r   
build_ftrl  s
    r   r,   c             K   s2   |dkrt dsttf d|i|}t| |S )Nr   ZGFtrl_ENGINE_SIMDr   )r	   r   r$   r   r   )r   r   rC   Zgftrl_optimizerr   r   r   build_gftrl  s    r   c             K   s"   t f d|i|}t| |||dS )Nr   )r   r   )r   r   )r   rB   
parametersr   r   rC   Zadagrad_optimizerr   r   r   build_adagrad  s    r   c             K   s"   t f d|i|}t| |||dS )Nr   )r   r   )r   r   )r   rB   r   r   r   rC   Zwngrad_optimizerr   r   r   build_wngrad  s    r   c             K   s"   t f d|i|}t| |||dS )Nr   )r   r   )r   r   )r   rB   r   r   r   rC   Zadadelta_optimizerr   r   r   build_adadelta  s    r  c             K   s"   t f d|i|}t| |||dS )Nr   )r   r   )r   r   )r   rB   r   r   rC   Zadam_optimizerr   r   r   
build_adam  s    r  皙?c             K   s   t f d|i|}t| |S )Nr   )r   r   )r   rB   rC   Zyellowfin_optimizerr   r   r   build_yellowfin  s    r  c             K   s"   t f d|i|}t| |||dS )Nr   )r   r   )r   r   )r   rB   r   r   rC   Zrms_prop_optimizerr   r   r   build_rms_prop  s    r  )NN)FTNF)NF)NF)r   )r,   )NNF)NNF)NNF)NF)r  )NF)>
__future__r   r   r   r   collectionsr   r   Zpast.builtinsr   loggingZnumpyrj   Zcaffe2.pythonr	   r
   r   r   Zcaffe2.python.modelingr   Zcaffe2.protor   r   r   intr   r   	getLoggerr   r   objectr   r\   r~   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r  r   r   r   r   <module>   sz   

 Ai9` WH4*lwW

;   
8 
 



  
  
  
 


 