import platform
import matplotlib.pyplot as plt
from os.path import expanduser, join
import dipy as dp
from dipy.tracking import utils
from dipy.data import default_sphere
from dipy.io import read_bvals_bvecs
from dipy.reconst.shm import CsaOdfModel
from dipy.direction import peaks_from_model
from dipy.viz import window, actor, has_fury,colormap
from dipy.core.gradients import gradient_table
from dipy.reconst.csdeconv import auto_response_ssst
from dipy.io.image import load_nifti,save_nifti,load_nifti_data # 保存和读取nifti
from dipy.tracking.stopping_criterion import ThresholdStoppingCriterion
from dipy.data import fetch_stanford_hardi,get_fnames
from dipy.tracking.local_tracking import LocalTracking
from dipy.tracking.streamline import Streamlines
from dipy.io.stateful_tractogram import Space,StatefulTractogram
from dipy.io.streamline import save_trk
# Report the library and interpreter versions this notebook was run with.
print(f"DIPY VERSION: {dp.__version__}")
print(f"PYTHON VERSION: {platform.python_version()}")
/usr/local/lib/python3.8/site-packages/setuptools/distutils_patch.py:25: UserWarning: Distutils was imported before Setuptools. This usage is discouraged and may exhibit undesirable behaviors or errors. Please use Setuptools' objects directly or at least import Setuptools first. warnings.warn(
DIPY VERSION: 1.5.0 PYTHON VERSION: 3.8.5
# Download the HARDI dataset. Note: calling this from Jupyter is for
# demonstration only; prefer downloading from a terminal with
# `dipy_fetch stanford_hardi`.
fetch_stanford_hardi()
({'HARDI150.nii.gz': ('https://stacks.stanford.edu/file/druid:yx282xq2090/dwi.nii.gz', '0b18513b46132b4d1051ed3364f2acbc'), 'HARDI150.bval': ('https://stacks.stanford.edu/file/druid:yx282xq2090/dwi.bvals', '4e08ee9e2b1d2ec3fddb68c70ae23c36'), 'HARDI150.bvec': ('https://stacks.stanford.edu/file/druid:yx282xq2090/dwi.bvecs', '4c63a586f29afc6a48a5809524a76cb4')}, '/Users/reallo/.dipy/stanford_hardi')
# Resolve the local paths of the diffusion volume and its b-values / b-vectors.
hardi_paths = get_fnames("stanford_hardi")
fname, fbval, fbvec = hardi_paths
# Bare expression: displays the three paths as the cell output.
fname, fbval, fbvec
('/Users/reallo/.dipy/stanford_hardi/HARDI150.nii.gz', '/Users/reallo/.dipy/stanford_hardi/HARDI150.bval', '/Users/reallo/.dipy/stanford_hardi/HARDI150.bvec')
# Resolve and print the local path of the "stanford_labels" file.
label_fname = get_fnames("stanford_labels")
print(label_fname)
/Users/reallo/.dipy/stanford_hardi/aparc-reduced.nii.gz
# Load the diffusion data; `return_img=True` also returns the image object,
# which is needed later to build the StatefulTractogram.
data, affine, img = load_nifti(fname, return_img=True)
# Only the label array is needed here, not its affine.
labels = load_nifti_data(label_fname)
print(f"dMRI data shape: {data.shape}")
print(f"dMRI labels shape: {labels.shape}")
dMRI data shape: (81, 106, 76, 160) dMRI labels shape: (81, 106, 76)
# Read the b-values / b-vectors from disk and build the gradient table used
# by the reconstruction model.
bvals, bvecs = read_bvals_bvecs(fbval, fbvec)
gtab = gradient_table(bvals, bvecs=bvecs)
本数据集中提供了所有白质的label map(1 or 2),我们可以创建一个白质的mask,来限制白质追踪的范围。
# Boolean mask that is True wherever the label volume equals 1 or 2.
is_label_1 = labels == 1
is_label_2 = labels == 2
white_matter = is_label_1 | is_label_2
# Bare expression: displays the mask shape as the cell output.
white_matter.shape
(81, 106, 76)
# Display the mask slice at index 30 along the last axis.
white_matter[:,:,30]
array([[False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], ..., [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False]])
# Show two slices of the white-matter mask side by side: index 30 along the
# last axis and index 60 along the middle axis (both transposed for display).
slice_last_axis = white_matter[:, :, 30].T
slice_middle_axis = white_matter[:, 60, :].T

plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(slice_last_axis, cmap='gray')
plt.subplot(1, 2, 2)
plt.imshow(slice_middle_axis, cmap='gray')
<matplotlib.image.AxesImage at 0x1257a23a0>
# Close the current figure now that it has been displayed.
plt.close()
# Automatic estimation of the single-shell single-tissue (ssst) response
# function using FA.
# https://dipy.org/documentation/1.5.0/reference/dipy.reconst/#auto-response-ssst
# NOTE(review): `reponse` (sic) and `ratio` are not used by any later cell
# visible in this notebook; the name is kept unchanged in case other code
# refers to it.
reponse, ratio = auto_response_ssst(
    gtab,
    data,
    roi_radii=10,
    fa_thr=0.7,
)

# Constant Solid Angle ODF model with spherical-harmonic order 6.
csa_model = CsaOdfModel(gtab, sh_order=6)
# https://dipy.org/documentation/1.5.0/reference/dipy.direction/#id27
# peaks_from_model fits `csa_model` inside the white-matter mask and returns
# an object whose attributes include `peak_dirs`, `peak_values` and `gfa`,
# all of which are used below.
csa_peaks = peaks_from_model(
    csa_model,
    data,
    default_sphere,
    relative_peak_threshold=0.8,
    min_separation_angle=45,
    mask=white_matter,
)
for peak_array in (csa_peaks.peak_dirs, csa_peaks.peak_values):
    print(peak_array.shape)
(81, 106, 76, 5, 3) (81, 106, 76, 5)
# DIPY provides a default triangulated sphere, repulsion724.
# https://github.com/dipy/dipy/blob/master/dipy/data/__init__.py#L158:5
default_sphere
<dipy.core.sphere.HemiSphere at 0x1104263a0>
# Visualize the direction field with the fury library.
# colors=None is passed explicitly instead of relying on the actor's default.
peaks_actor = actor.peak_slicer(
    csa_peaks.peak_dirs,
    csa_peaks.peak_values,
    colors=None,
)
scene = window.Scene()
scene.add(peaks_actor)
# To save a screenshot instead of opening a window:
# window.record(scene,out_path='csa_direction_field.png',size=(900,900))
# Interactive viewer.
window.show(scene, size=(800, 800))

gfa_map = csa_peaks.gfa
print(gfa_map.shape)
print(type(gfa_map))
(81, 106, 76) <class 'numpy.ndarray'>
# Take a look at the GFA stored in csa_peaks (slice 30 along the last axis,
# transposed for display).
gfa_slice = csa_peaks.gfa[:, :, 30].T
plt.figure()
plt.imshow(gfa_slice, cmap='gray')
plt.title("GFA, generalized fractional anisotropy ")
Text(0.5, 1.0, 'GFA, generalized fractional anisotropy ')
# Preview which voxels of slice 30 exceed GFA 0.25, the same value used as
# the tracking stop threshold further down.
gfa_above_threshold = csa_peaks.gfa[:, :, 30] > 0.25
plt.imshow(gfa_above_threshold.T, cmap='gray', origin='lower')
<matplotlib.image.AxesImage at 0x1283fcf40>
# Close the current figure, then display the raw GFA values of the slice at
# index 30 along the last axis.
plt.close()
csa_peaks.gfa[:,:,30]
array([[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]])
### Step 2: define when fiber tracking should stop
# Arguments: the stop map (here the GFA volume) and the stop threshold.
stopping_criterion = ThresholdStoppingCriterion(csa_peaks.gfa, 0.25)
stopping_criterion
<dipy.tracking.stopping_criterion.ThresholdStoppingCriterion at 0x12789d5a0>
### Step 3: set up the tracking seeds
# Seed from every voxel whose label equals 2, with a seed density of 2 along
# each of the three axes.
seed_mask = labels == 2
seeds = utils.seeds_from_mask(seed_mask, affine, density=[2, 2, 2])
# Bare expression: displays the boolean seed mask as the cell output.
seed_mask
array([[[False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], ..., [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False]], [[False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], ..., [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False]], [[False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], ..., [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False]], ..., [[False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], ..., [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False]], [[False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], ..., [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False]], [[False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], ..., [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False]]])
# Initialization of LocalTracking. The computation happens in the next step,
# when the generator is consumed.
stream_generator = LocalTracking(
    csa_peaks,
    stopping_criterion,
    seeds,
    affine=affine,
    step_size=0.5,
)
# Generate the streamlines object (this is where tracking actually runs).
streamlines = Streamlines(stream_generator)

# Compute the per-streamline colors once and reuse them for the actor; the
# original computed them a second time and left `color` unused.
color = colormap.line_colors(streamlines)
streamlines_actor = actor.line(streamlines, color)

# Create the 3D display.
scene = window.Scene()
scene.add(streamlines_actor)

# Save a figure, then open the interactive viewer.
# NOTE(review): no out_path is given, so the file is written to the
# library's default location; pass out_path=... to control it.
window.record(scene)
window.show(scene)

# Wrap the streamlines with their reference image and coordinate space so
# they can be written to disk.
sft = StatefulTractogram(streamlines, img, Space.RASMM)
# Bare expression: displays the streamlines as the cell output.
streamlines
ArraySequence([array([[ -4.5, -44.5, 7.5]]), array([[ -4.5, -44.5, 7.5]]), array([[ -4.5, -44.5, 7.5]]), ..., array([[-0.5, 32.5, 12.5]]), array([[ 0.5, 32.5, 12.5]]), array([[ 0.5, 32.5, 12.5]])])
# Write the tractogram to a .trk file. The streamlines are already carried
# by `sft`; save_trk's third positional parameter is `bbox_valid_check`, so
# passing `streamlines` there (as the original did) only worked by accident
# through the truthiness of a non-empty sequence.
save_trk(sft, "tractogram_EuDX.trk")